コード例 #1
0
    def run_net(self):
        """Build the data pipelines and the forked-DenseNet graph, then train.

        Steps:
          1. Read train/validation/test image paths with ``_read_data``.
          2. Start the validation and training producer threads
             (``fill_thread`` / ``_extractor_thread`` / ``read_thread``) that
             populate the ``image_class`` patch buffers.
          3. Build placeholders (8 input rows, 14 labels, ``is_training``,
             ``input_dim``), the ``_forked_densenet2`` network, the
             ``MSE_perf_angio`` loss (reduced to ``cost``) and an Adam
             optimizer that also runs the ``UPDATE_OPS`` collection.
          4. Run ``self.total_epochs`` epochs. Every
             ``self.display_validation_step`` iterations, evaluate on the
             validation buffer. TensorBoard summaries go under
             ``self.LOGDIR`` and checkpoints under ``self.chckpnt_dir``.
        """
        _rd = _read_data(data=self.data,
                         img_name=self.img_name,
                         label_name=self.label_name,
                         dataset_path=self.data_path)

        self.alpha_coeff = 1
        # Read the paths of the images for train, test, and validation.
        train_data, validation_data, test_data = _rd.read_data_path()

        bunch_of_images_no = 20
        sample_no = 60

        # ================= validation input pipeline =================
        _image_class_vl = image_class(
            validation_data,
            bunch_of_images_no=bunch_of_images_no,
            is_training=0,
            patch_window=self.patch_window,
            sample_no_per_bunch=sample_no,
            label_patch_size=self.label_patchs_size,
            validation_total_sample=self.validation_samples)
        _patch_extractor_thread_vl = _extractor_thread(
            _image_class=_image_class_vl,
            patch_window=self.patch_window,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=0,
            vl_sample_no=self.validation_samples)
        _fill_thread_vl = fill_thread(
            data=validation_data,
            _image_class=_image_class_vl,
            sample_no=sample_no,
            total_sample_no=self.validation_samples,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=0,
            patch_extractor=_patch_extractor_thread_vl,
            fold=self.fold)
        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl = read_thread(
            _fill_thread_vl,
            mutex=settings.mutex,
            validation_sample_no=self.validation_samples,
            is_training=0)
        _read_thread_vl.start()

        # ================= training input pipeline =================
        _image_class = image_class(train_data,
                                   bunch_of_images_no=bunch_of_images_no,
                                   is_training=1,
                                   patch_window=self.patch_window,
                                   sample_no_per_bunch=sample_no,
                                   label_patch_size=self.label_patchs_size,
                                   validation_total_sample=0)
        patch_extractor_thread = _extractor_thread(
            _image_class=_image_class,
            patch_window=self.patch_window,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=1)
        _fill_thread = fill_thread(train_data,
                                   _image_class,
                                   sample_no=sample_no,
                                   total_sample_no=self.sample_no,
                                   label_patchs_size=self.label_patchs_size,
                                   mutex=settings.mutex,
                                   is_training=1,
                                   patch_extractor=patch_extractor_thread,
                                   fold=self.fold)
        _fill_thread.start()
        patch_extractor_thread.start()
        _read_thread = read_thread(_fill_thread,
                                   mutex=settings.mutex,
                                   is_training=1)
        _read_thread.start()

        # ================= graph =================
        # Every input row and label is a float32 tensor of rank 5 with a
        # trailing channel dimension of 1; all other dimensions are dynamic.
        placeholder_shape = [None, None, None, None, 1]
        img_rows = [
            tf.placeholder(tf.float32, shape=placeholder_shape,
                           name='img_row%d' % i)
            for i in range(1, 9)
        ]
        labels = [
            tf.placeholder(tf.float32, shape=placeholder_shape,
                           name='label%d' % i)
            for i in range(1, 15)
        ]
        is_training = tf.placeholder(tf.bool, name='is_training')
        input_dim = tf.placeholder(tf.int32, name='input_dim')

        forked_densenet = _forked_densenet2()
        net_inputs = {'img_row%d' % i: ph
                      for i, ph in enumerate(img_rows, 1)}
        (perf_y, perf_loss_fm1, perf_loss_fm2,
         angio_y, angio_loss_fm1, angio_loss_fm2) = forked_densenet.densenet(
             input_dim=input_dim, is_training=is_training, **net_inputs)

        # Image summaries use the slice at the middle index of axis 1.
        mid = int(self.label_patchs_size / 2)
        # out0..out6 <- perf_y channels 0..6; out7..out13 <- angio_y 0..6.
        pred_slices = ([perf_y[:, mid, :, :, c, np.newaxis] for c in range(7)]
                       + [angio_y[:, mid, :, :, c, np.newaxis]
                          for c in range(7)])
        label_slices = [lab[:, mid, :, :, 0, np.newaxis] for lab in labels]
        for i, img in enumerate(pred_slices):
            tf.summary.image('out%d' % i, img, 3)
        for i, img in enumerate(label_slices, 1):
            tf.summary.image('groundtruth%d' % i, img, 3)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train',
                                             graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation',
                                                  graph=sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        utils.backup_code(self.LOGDIR)

        # Loss: label1..7 pair with perf_y channels 0..6 (logit1..7),
        # label8..14 with angio_y channels 0..6 (logit8..14).
        with tf.name_scope('MSE_perf_angio'):
            loss_inputs = {'label%d' % i: ph
                           for i, ph in enumerate(labels, 1)}
            for c in range(7):
                loss_inputs['logit%d' % (c + 1)] = \
                    perf_y[:, :, :, :, c, np.newaxis]
            for c in range(7):
                loss_inputs['logit%d' % (c + 8)] = \
                    angio_y[:, :, :, :, c, np.newaxis]
            MSE_perf_angio = self.loss_instance.MSE_perf_angio(
                perf_loss_fm1=perf_loss_fm1,
                perf_loss_fm2=perf_loss_fm2,
                angio_loss_fm1=angio_loss_fm1,
                angio_loss_fm2=angio_loss_fm2,
                **loss_inputs)
            cost = tf.reduce_mean(MSE_perf_angio, name="cost")
        # Log the reduced scalar; a scalar summary needs a rank-0 value.
        tf.summary.scalar("cost", cost)

        # Run UPDATE_OPS (e.g. normalisation statistics) with each step.
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            optimizer = tf.train.AdamOptimizer(
                self.learning_rate).minimize(cost)

        sess.run(tf.global_variables_initializer())
        logging.debug('total number of variables %s', np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))

        summ = tf.summary.merge_all()

        def _feed(crush, noncrush, perf, angio, training=False):
            """Map one batch of patch arrays onto the graph placeholders.

            The input rows alternate between the two sources:
            img_row1 <- crush[:, 0], img_row2 <- noncrush[:, 1], and so on.
            label1..7 <- perf[:, 0..6] and label8..14 <- angio[:, 0..6].
            """
            feed = {is_training: training, input_dim: self.patch_window}
            for i, ph in enumerate(img_rows):
                source = crush if i % 2 == 0 else noncrush
                feed[ph] = source[:, i, :, :, :, :]
            for i, ph in enumerate(labels):
                source = perf if i < 7 else angio
                feed[ph] = source[:, i % 7, :, :, :, :]
            return feed

        loadModel = 0
        point = 0
        itr1 = 0
        if loadModel:
            chckpnt_dir = ''
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Plain int: np.int16 overflowed for global steps > 32767.
            point = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            itr1 = point

        # ================= epochs =================
        for epoch in range(self.total_epochs):
            # Time the whole epoch. Starting the timer here also keeps it
            # defined when the loop below does not run.
            startTime = time.time()
            # NOTE(review): ``point`` is never reset between epochs. Once this
            # condition fails it stays false, so later epochs only save a
            # checkpoint. Confirm that this is intended.
            while (self.no_sample_per_each_itr *
                   int(point / self.no_sample_per_each_itr) < self.sample_no):
                print('0')
                print("epoch #: %d" % (epoch))

                step = 0
                # ============= validation =============
                # The whole validation section runs only on validation
                # iterations. Previously only the counter reset was guarded.
                if itr1 % self.display_validation_step == 0:
                    loss_validation = 0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation = 0

                    while (validation_step * self.batch_no_validation <
                           settings.validation_totalimg_patch):
                        crush, noncrush, perf, angio = \
                            _image_class_vl.return_patches_vl(
                                validation_step * self.batch_no_validation,
                                (validation_step + 1) *
                                self.batch_no_validation,
                                is_tr=False)
                        if len(angio) < self.batch_no_validation:
                            # Buffer not full yet: wake the reader and retry.
                            _read_thread_vl.resume()
                            time.sleep(0.5)
                            continue

                        [loss_vali] = sess.run(
                            [cost],
                            feed_dict=_feed(crush, noncrush, perf, angio))
                        loss_validation += loss_vali

                        validation_step += 1
                        if np.isnan(dsc_validation) or np.isnan(
                                loss_validation) or np.isnan(acc_validation):
                            print('nan problem')
                        process = psutil.Process(os.getpid())
                        print(
                            '%d - > %d:   loss_validation: %f, memory_percent: %4s'
                            % (
                                validation_step,
                                validation_step * self.batch_no_validation,
                                loss_vali,
                                str(process.memory_percent()),
                            ))

                    settings.queue_isready_vl = False
                    acc_validation = acc_validation / (validation_step)
                    loss_validation = loss_validation / (validation_step)
                    dsc_validation = dsc_validation / (validation_step)

                    if np.isnan(dsc_validation) or np.isnan(
                            loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    _fill_thread_vl.kill_thread()
                    print(
                        '******Validation, step: %d , accuracy: %.4f, loss: %f*******'
                        % (itr1, acc_validation, loss_validation))
                    # Summaries are computed on the last validation batch.
                    [sum_validation] = sess.run(
                        [summ],
                        feed_dict=_feed(crush, noncrush, perf, angio))
                    validation_writer.add_summary(sum_validation, point)
                    print('end of validation---------%d' % (point))

                # ============= training batches =============
                while step * self.batch_no < self.no_sample_per_each_itr:
                    crush, noncrush, perf, angio = \
                        _image_class.return_patches_tr(self.batch_no)

                    if len(angio) < self.batch_no:
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    batch_feed = _feed(crush, noncrush, perf, angio)
                    # NOTE(review): the optimizer step is fed
                    # is_training=False, exactly as before. With
                    # training-flag-controlled layers (batch norm, dropout)
                    # this disables their training behaviour. Confirm whether
                    # this is deliberate; pass training=True otherwise.
                    loss_train1, _ = sess.run([cost, optimizer],
                                              feed_dict=batch_feed)

                    self.x_hist = self.x_hist + 1

                    [sum_train] = sess.run([summ], feed_dict=batch_feed)
                    train_writer.add_summary(sum_train, point)
                    step = step + 1

                    process = psutil.Process(os.getpid())
                    print(
                        'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s'
                        % (point, step * self.batch_no, self.learning_rate,
                           loss_train1, str(process.memory_percent())))

                    if point % 100 == 0:
                        # Save an intra-epoch checkpoint every 100 points.
                        chckpnt_path = os.path.join(
                            self.chckpnt_dir,
                            ('unet_inter_epoch%d_point%d.ckpt' %
                             (epoch, point)))
                        saver.save(sess, chckpnt_path, global_step=point)

                    itr1 = itr1 + 1
                    point = point + 1

            endTime = time.time()
            # ============= end of epoch: save the model =============
            chckpnt_path = os.path.join(self.chckpnt_dir, 'unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=epoch)

            print("End of epoch----> %d, elapsed time: %d" %
                  (epoch, endTime - startTime))
コード例 #2
0
                     img_name=img_name,
                     label_name=label_name,
                     dataset_path=data_path)

    train_data, validation_data, test_data = _rd.read_data_path(fold=fold)
    input_cube_size = 47
    gt_cube_size = input_cube_size
    test_vali = 0
    if test_vali == 1:
        test_set = validation_data
    else:
        test_set = test_data
    img_class = image_class(test_set,
                            bunch_of_images_no=1,
                            is_training=1,
                            patch_window=input_cube_size,
                            sample_no_per_bunch=1,
                            label_patch_size=1,
                            validation_total_sample=0)
    sess = tf.Session()
    # tf.initializers.global_variables()

    loss = 0
    for img_indx in range(len(test_set)):

        crush_noncrush_perf_angio = img_class.read_image_for_test(
            test_set, img_indx)
        for j in range(7):
            for k in range(2, 4):
                img_size = np.shape(crush_noncrush_perf_angio[k][j])[0]
                input = crush_noncrush_perf_angio[k][j][
コード例 #3
0
    def run_net(self,loadModel = 0):

        # pre_bn=tf.placeholder(tf.float32,shape=[None,None,None,None,None])
        # image=tf.placeholder(tf.float32,shape=[self.batch_no,self.patch_window,self.patch_window,self.patch_window,1])
        # label=tf.placeholder(tf.float32,shape=[self.batch_no_validation,self.label_patchs_size,self.label_patchs_size,self.label_patchs_size,2])
        # loss_coef=tf.placeholder(tf.float32,shape=[self.batch_no_validation,1,1,1])
        # ===================================================================================
        _rd = _read_data(data=self.data,
                         img_name=self.img_name, label_name=self.label_name,
                         dataset_path=self.data_path,reverse=self.newdataset)

        self.alpha_coeff = 1
        '''read path of the images for train, test, and validation'''
        train_data, validation_data, test_data = _rd.read_data_path()

        # ======================================
        bunch_of_images_no = 20
        sample_no = 60
        _image_class_vl = image_class(validation_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=0,
                                      patch_window=self.patch_window,
                                      sample_no_per_bunch=sample_no,
                                      label_patch_size=self.label_patchs_size,
                                      validation_total_sample=self.validation_samples)

        _patch_extractor_thread_vl = _extractor_thread(_image_class=_image_class_vl,
                                                       patch_window=self.patch_window,
                                                       label_patchs_size=self.label_patchs_size,

                                                       mutex=settings.mutex,
                                                       is_training=0,
                                                       vl_sample_no=self.validation_samples
                                                       )
        _fill_thread_vl = fill_thread(data=validation_data,
                                      _image_class=_image_class_vl,
                                      sample_no=sample_no,
                                      total_sample_no=self.validation_samples,
                                      label_patchs_size=self.label_patchs_size,
                                      mutex=settings.mutex,
                                      is_training=0,
                                      patch_extractor=_patch_extractor_thread_vl,
                                      fold=self.fold)

        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl = read_thread(_fill_thread_vl,
                                      mutex=settings.mutex,
                                      validation_sample_no=self.validation_samples,
                                      is_training=0)
        _read_thread_vl.start()
        # ======================================
        bunch_of_images_no = 20
        sample_no = 60
        _image_class = image_class(train_data
                                   , bunch_of_images_no=bunch_of_images_no,
                                   is_training=1,
                                   patch_window=self.patch_window,
                                   sample_no_per_bunch=sample_no,
                                   label_patch_size=self.label_patchs_size,
                                   validation_total_sample=0)

        patch_extractor_thread = _extractor_thread(_image_class=_image_class,
                                                   patch_window=self.patch_window,
                                                   label_patchs_size=self.label_patchs_size,
                                                   mutex=settings.mutex, is_training=1)
        _fill_thread = fill_thread(train_data,
                                   _image_class,
                                   sample_no=sample_no, total_sample_no=self.sample_no,
                                   label_patchs_size=self.label_patchs_size,
                                   mutex=settings.mutex,
                                   is_training=1,
                                   patch_extractor=patch_extractor_thread,
                                   fold=self.fold)
        #
        _fill_thread.start()
        patch_extractor_thread.start()

        _read_thread = read_thread(_fill_thread, mutex=settings.mutex, is_training=1)
        _read_thread.start()
        # ===================================================================================
        img_row1 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row1')
        img_row2 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row2')
        img_row3 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row3')
        img_row4 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row4')
        img_row5 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row5')
        img_row6 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row6')
        img_row7 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row7')
        img_row8 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row8')

        mri_ph = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='mri')

        label1 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label1')
        label2 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label2')
        label3 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label3')
        label4 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label4')
        label5 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label5')
        label6 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label6')
        label7 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label7')
        label8 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label8')
        label9 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label9')
        label10 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label10')
        label11 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label11')
        label12 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label12')
        label13 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label13')
        label14 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label14')

        # loss_placeholder = tf.placeholder(tf.float32,
        #                                   shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
        #                                          self.label_patchs_size, 1])

        # img_row1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row1')
        # img_row2 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row2')
        # img_row3 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row3')
        # img_row4 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row4')
        # img_row5 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row5')
        # img_row6 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row6')
        # img_row7 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row7')
        # img_row8 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row8')
        #
        # label1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label1')
        # label2 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label2')
        # label3 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label3')
        # label4 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label4')
        # label5 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label5')
        # label6 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label6')
        # label7 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label7')
        # label8 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label8')
        # label9 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label9')
        # label10 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label10')
        # label11 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label11')
        # label12 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label12')
        # label13 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label13')
        # label14 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label14')

        is_training = tf.placeholder(tf.bool, name='is_training')
        input_dim = tf.placeholder(tf.int32, name='input_dim')

        # unet=_unet()
        multi_stage_densenet = _multi_stage_densenet()

        y,loss_upsampling11,loss_upsampling22 = multi_stage_densenet.multi_stage_densenet(img_row1=img_row1,
                              img_row2=img_row2,
                              img_row3=img_row3,
                              img_row4=img_row4,
                              img_row5=img_row5,
                              img_row6=img_row6,
                              img_row7=img_row7,
                              img_row8=img_row8,
                              input_dim=input_dim,
                              mri = mri_ph,
                              is_training=is_training)
        # self.vgg = vgg_feature_maker()
        # self.vgg_y0 = self.vgg.feed_img(loss_placeholder)
        # self.vgg_y0 = self.vgg.feed_img(y[:, :, :, :, 0]).copy()
        # self.vgg_y1 = self.vgg.feed_img(y[:, :, :, :, 1]).copy()
        # self.vgg_y2 = self.vgg.feed_img(y[:, :, :, :, 2]).copy()
        # self.vgg_y3 = self.vgg.feed_img(y[:, :, :, :, 3]).copy()
        # self.vgg_y4 = self.vgg.feed_img(y[:, :, :, :, 4]).copy()
        # self.vgg_y5 = self.vgg.feed_img(y[:, :, :, :, 5]).copy()
        # self.vgg_y6 = self.vgg.feed_img(y[:, :, :, :, 6]).copy()
        # self.vgg_y7 = self.vgg.feed_img(y[:, :, :, :, 7])
        # self.vgg_y8 = self.vgg.feed_img(y[:, :, :, :, 8])
        # self.vgg_y9 = self.vgg.feed_img(y[:, :, :, :, 9])
        # self.vgg_y10 = self.vgg.feed_img(y[:, :, :, :, 10])
        # self.vgg_y11 = self.vgg.feed_img(y[:, :, :, :, 11])
        # self.vgg_y12 = self.vgg.feed_img(y[:, :, :, :, 12])
        # self.vgg_y13 = self.vgg.feed_img(y[:, :, :, :, 13])

        # self.vgg_label0 = self.vgg.feed_img(label1[:,:,:,:,0]).copy()
        # self.vgg_label1 = self.vgg.feed_img(label2[:,:,:,:,0]).copy()
        # self.vgg_label2 = self.vgg.feed_img(label3[:,:,:,:,0]).copy()
        # self.vgg_label3 = self.vgg.feed_img(label4[:,:,:,:,0]).copy()
        # self.vgg_label4 = self.vgg.feed_img(label5[:,:,:,:,0]).copy()
        # self.vgg_label5 = self.vgg.feed_img(label6[:,:,:,:,0]).copy()
        # self.vgg_label6 = self.vgg.feed_img(label7[:,:,:,:,0]).copy()
        # self.vgg_label7 = self.vgg.feed_img(label8[:,:,:,:,0])
        # self.vgg_label8 = self.vgg.feed_img(label9[:,:,:,:,0])
        # self.vgg_label9 = self.vgg.feed_img(label10[:,:,:,:,0])
        # self.vgg_label10 = self.vgg.feed_img(label11[:,:,:,:,0])
        # self.vgg_label11 = self.vgg.feed_img(label12[:,:,:,:,0])
        # self.vgg_label12 = self.vgg.feed_img(label13[:,:,:,:,0])
        # self.vgg_label13 = self.vgg.feed_img(label14[:,:,:,:,0])

        y_dirX = ((y[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis]))
        y_dirX1 = ((y[:, int(self.label_patchs_size / 2), :, :, 1, np.newaxis]))
        y_dirX2 = ((y[:, int(self.label_patchs_size / 2), :, :, 2, np.newaxis]))
        y_dirX3 = ((y[:, int(self.label_patchs_size / 2), :, :, 3, np.newaxis]))
        y_dirX4 = ((y[:, int(self.label_patchs_size / 2), :, :, 4, np.newaxis]))
        y_dirX5 = ((y[:, int(self.label_patchs_size / 2), :, :, 5, np.newaxis]))
        y_dirX6 = ((y[:, int(self.label_patchs_size / 2), :, :, 6, np.newaxis]))
        y_dirX7 = ((y[:, int(self.label_patchs_size / 2), :, :, 7, np.newaxis]))
        y_dirX8 = ((y[:, int(self.label_patchs_size / 2), :, :, 8, np.newaxis]))
        y_dirX9 = ((y[:, int(self.label_patchs_size / 2), :, :, 9, np.newaxis]))
        y_dirX10 = ((y[:, int(self.label_patchs_size / 2), :, :, 10, np.newaxis]))
        y_dirX11 = ((y[:, int(self.label_patchs_size / 2), :, :, 11, np.newaxis]))
        y_dirX12 = ((y[:, int(self.label_patchs_size / 2), :, :, 12, np.newaxis]))
        y_dirX13 = ((y[:, int(self.label_patchs_size / 2), :, :, 13, np.newaxis]))

        label_dirX1 = (label1[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX2 = (label2[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX3 = (label3[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX4 = (label4[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX5 = (label5[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX6 = (label6[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX7 = (label7[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX8 = (label8[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX9 = (label9[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX10 = (label10[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX11 = (label11[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX12 = (label12[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX13 = (label13[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX14 = (label14[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])

        tf.summary.image('out0', y_dirX, 3)
        tf.summary.image('out1', y_dirX1, 3)
        tf.summary.image('out2', y_dirX2, 3)
        tf.summary.image('out3', y_dirX3, 3)
        tf.summary.image('out4', y_dirX4, 3)
        tf.summary.image('out5', y_dirX5, 3)
        tf.summary.image('out6', y_dirX6, 3)
        tf.summary.image('out7', y_dirX7, 3)
        tf.summary.image('out8', y_dirX8, 3)
        tf.summary.image('out9', y_dirX9, 3)
        tf.summary.image('out10', y_dirX10, 3)
        tf.summary.image('out11', y_dirX11, 3)
        tf.summary.image('out12', y_dirX12, 3)
        tf.summary.image('out13', y_dirX13, 3)

        tf.summary.image('groundtruth1', label_dirX1, 3)
        tf.summary.image('groundtruth2', label_dirX2, 3)
        tf.summary.image('groundtruth3', label_dirX3, 3)
        tf.summary.image('groundtruth4', label_dirX4, 3)
        tf.summary.image('groundtruth5', label_dirX5, 3)
        tf.summary.image('groundtruth6', label_dirX6, 3)
        tf.summary.image('groundtruth7', label_dirX7, 3)
        tf.summary.image('groundtruth8', label_dirX8, 3)
        tf.summary.image('groundtruth9', label_dirX9, 3)
        tf.summary.image('groundtruth10', label_dirX10, 3)
        tf.summary.image('groundtruth11', label_dirX11, 3)
        tf.summary.image('groundtruth12', label_dirX12, 3)
        tf.summary.image('groundtruth13', label_dirX13, 3)
        tf.summary.image('groundtruth14', label_dirX14, 3)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train' , graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation' , graph=sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if loadModel==0:
            utils.backup_code(self.LOGDIR)
        labels = []
        labels.append(label1)
        labels.append(label2)
        labels.append(label3)
        labels.append(label4)
        labels.append(label5)
        labels.append(label6)
        labels.append(label7)
        labels.append(label8)
        labels.append(label9)
        labels.append(label10)
        labels.append(label11)
        labels.append(label12)
        labels.append(label13)
        labels.append(label14)

        logits = []
        logits.append(y[:, :, :, :, 0, np.newaxis])
        logits.append(y[:, :, :, :, 1, np.newaxis])
        logits.append(y[:, :, :, :, 2, np.newaxis])
        logits.append(y[:, :, :, :, 3, np.newaxis])
        logits.append(y[:, :, :, :, 4, np.newaxis])
        logits.append(y[:, :, :, :, 5, np.newaxis])
        logits.append(y[:, :, :, :, 6, np.newaxis])

        logits.append(y[:, :, :, :, 7, np.newaxis])
        logits.append(y[:, :, :, :, 8, np.newaxis])
        logits.append(y[:, :, :, :, 9, np.newaxis])
        logits.append(y[:, :, :, :, 10, np.newaxis])
        logits.append(y[:, :, :, :, 11, np.newaxis])
        logits.append(y[:, :, :, :, 12, np.newaxis])
        logits.append(y[:, :, :, :, 13, np.newaxis])

        stage1 = []
        stage1.append(loss_upsampling11[:, :, :, :, 0, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 1, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 2, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 3, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 4, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 5, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 6, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 7, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 8, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 9, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 10, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 11, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 12, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 13, np.newaxis])

        stage2 = []
        stage2.append(loss_upsampling22[:, :, :, :, 0, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 1, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 2, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 3, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 4, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 5, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 6, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 7, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 8, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 9, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 10, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 11, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 12, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 13, np.newaxis])



        with tf.name_scope('Loss'):
            loss_dic = self.loss_instance.loss_selector('Multistage_ssim_perf_angio_loss',
                                                        labels=labels,logits=logits,
                                                        stage1=stage1,
                                                        stage2=stage2)
            cost = tf.reduce_mean(loss_dic["loss"], name="cost")

        with tf.variable_scope("summary"):
            tf.summary.scalar("Loss/loss", cost)

        # ============================================
        all_loss = tf.placeholder(tf.float32, name='loss')
        with tf.name_scope('validation'):
            ave_loss = all_loss
        tf.summary.scalar("ave_loss", ave_loss)
        '''AdamOptimizer:'''
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(cost)

        sess.run(tf.global_variables_initializer())
        logging.debug('total number of variables %s' % (
            np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))

        summ = tf.summary.merge_all()

        point = 0
        if loadModel:
            chckpnt_dir = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/MRI_in/experiment-2/unet_checkpoints/'
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            point = (ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            point = int(int(point) / self.display_train_step) * self.display_train_step

        '''loop for epochs'''
        for epoch in range(self.total_epochs):
            while self.no_sample_per_each_itr * int(point / self.no_sample_per_each_itr) < self.sample_no:
                print('0')
                print("epoch #: %d" % (epoch))
                startTime = time.time()

                step = 0
                # =============validation================
                if point % self.display_validation_step == 0:
                    '''Validation: '''
                    loss_validation = 0
                    loss_validation_p_ssim=0
                    loss_validation_p_perceptual=0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation = 0


                while (validation_step * self.batch_no_validation < settings.validation_totalimg_patch):

                    [crush, noncrush, perf, angio,mri,segmentation] = _image_class_vl.return_patches_vl(
                        validation_step * self.batch_no_validation,
                        (validation_step + 1) * self.batch_no_validation, is_tr=False
                        )
                    if (len(segmentation) < self.batch_no_validation):
                        _read_thread_vl.resume()
                        time.sleep(0.5)
                        continue
                    # continue

                    [loss_vali] = sess.run([cost],
                                  feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                             img_row2: noncrush[:, 1, :, :, :, :],
                                             img_row3: crush[:, 2, :, :, :, :],
                                             img_row4: noncrush[:, 3, :, :, :, :],
                                             img_row5: crush[:, 4, :, :, :, :],
                                             img_row6: noncrush[:, 5, :, :, :, :],
                                             img_row7: crush[:, 6, :, :, :, :],
                                             img_row8: noncrush[:, 7, :, :, :, :],
                                             mri_ph: mri[:, 0, :, :, :, :],
                                             label1: perf[:, 0, :, :, :, :],
                                             label2: perf[:, 1, :, :, :, :],
                                             label3: perf[:, 2, :, :, :, :],
                                             label4: perf[:, 3, :, :, :, :],
                                             label5: perf[:, 4, :, :, :, :],
                                             label6: perf[:, 5, :, :, :, :],
                                             label7: perf[:, 6, :, :, :, :],
                                             label8: angio[:, 0, :, :, :, :],
                                             label9: angio[:, 1, :, :, :, :],
                                             label10: angio[:, 2, :, :, :, :],
                                             label11: angio[:, 3, :, :, :, :],
                                             label12: angio[:, 4, :, :, :, :],
                                             label13: angio[:, 5, :, :, :, :],
                                             label14: angio[:, 6, :, :, :, :],
                                             is_training: False,
                                             input_dim: self.patch_window,
                                             all_loss: -1.,
                                             })
                    loss_validation += loss_vali

                    validation_step += 1
                    if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    process = psutil.Process(os.getpid())
                    print(
                        '%d - > %d:   loss_validation: %f, memory_percent: %4s' % (
                            validation_step, validation_step * self.batch_no_validation
                            , loss_vali, str(process.memory_percent()),
                        ))

                settings.queue_isready_vl = False
                acc_validation = acc_validation / (validation_step)
                loss_validation = loss_validation / (validation_step)
                loss_validation_p_ssim = loss_validation_p_ssim / (validation_step)
                loss_validation_p_perceptual = loss_validation_p_perceptual / (validation_step)
                dsc_validation = dsc_validation / (validation_step)

                if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                    print('nan problem')
                _fill_thread_vl.kill_thread()
                print('******Validation, step: %d , accuracy: %.4f, loss: %f*******' % (
                    point, acc_validation, loss_validation))
                [sum_validation] = sess.run([summ],
                                            feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                       img_row2: noncrush[:, 1, :, :, :, :],
                                                       img_row3: crush[:, 2, :, :, :, :],
                                                       img_row4: noncrush[:, 3, :, :, :, :],
                                                       img_row5: crush[:, 4, :, :, :, :],
                                                       img_row6: noncrush[:, 5, :, :, :, :],
                                                       img_row7: crush[:, 6, :, :, :, :],
                                                       img_row8: noncrush[:, 7, :, :, :, :],
                                                       mri_ph: mri[:, 0, :, :, :, :],
                                                       label1: perf[:, 0, :, :, :, :],
                                                       label2: perf[:, 1, :, :, :, :],
                                                       label3: perf[:, 2, :, :, :, :],
                                                       label4: perf[:, 3, :, :, :, :],
                                                       label5: perf[:, 4, :, :, :, :],
                                                       label6: perf[:, 5, :, :, :, :],
                                                       label7: perf[:, 6, :, :, :, :],
                                                       label8: angio[:, 0, :, :, :, :],
                                                       label9: angio[:, 1, :, :, :, :],
                                                       label10: angio[:, 2, :, :, :, :],
                                                       label11: angio[:, 3, :, :, :, :],
                                                       label12: angio[:, 4, :, :, :, :],
                                                       label13: angio[:, 5, :, :, :, :],
                                                       label14: angio[:, 6, :, :, :, :],
                                                       is_training: False,
                                                       input_dim: self.patch_window,
                                                       all_loss: loss_validation,
                                                       })
                validation_writer.add_summary(sum_validation, point)
                print('end of validation---------%d' % (point))

                '''loop for training batches'''
                while (step * self.batch_no < self.no_sample_per_each_itr):

                    # [train_CT_image_patchs, train_GTV_label, train_Penalize_patch,loss_coef_weights] = _image_class.return_patches( self.batch_no)

                    [crush, noncrush, perf, angio,mri,segmentation] = _image_class.return_patches_tr(self.batch_no)

                    if (len(segmentation) < self.batch_no):
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    if point % self.display_train_step == 0:
                        '''train: '''
                        train_step=0

                    while (train_step <self.display_train_step):
                        if train_step==0:
                            average_loss_train1=0
                        [loss_train1, optimizing, out ] = sess.run([cost, optimizer, y,],
                                      feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                 img_row2: noncrush[:, 1, :, :, :, :],
                                                 img_row3: crush[:, 2, :, :, :, :],
                                                 img_row4: noncrush[:, 3, :, :, :, :],
                                                 img_row5: crush[:, 4, :, :, :, :],
                                                 img_row6: noncrush[:, 5, :, :, :, :],
                                                 img_row7: crush[:, 6, :, :, :, :],
                                                 img_row8: noncrush[:, 7, :, :, :, :],
                                                 mri_ph: mri[:, 0, :, :, :, :],
                                                 label1: perf[:, 0, :, :, :, :],
                                                 label2: perf[:, 1, :, :, :, :],
                                                 label3: perf[:, 2, :, :, :, :],
                                                 label4: perf[:, 3, :, :, :, :],
                                                 label5: perf[:, 4, :, :, :, :],
                                                 label6: perf[:, 5, :, :, :, :],
                                                 label7: perf[:, 6, :, :, :, :],
                                                 label8: angio[:, 0, :, :, :, :],
                                                 label9: angio[:, 1, :, :, :, :],
                                                 label10: angio[:, 2, :, :, :, :],
                                                 label11: angio[:, 3, :, :, :, :],
                                                 label12: angio[:, 4, :, :, :, :],
                                                 label13: angio[:, 5, :, :, :, :],
                                                 label14: angio[:, 6, :, :, :, :],
                                                 is_training: False,
                                                 input_dim: self.patch_window,
                                                 all_loss: -1,
                                                 })
                        average_loss_train1+=loss_train1

                        train_step=train_step+1
                        print(
                            'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s, ' % (
                                int((point+train_step)),
                                step * self.batch_no, self.learning_rate, loss_train1,
                                str(process.memory_percent())))

                    average_loss_train1 /= self.display_train_step


                    [sum_train] = sess.run([summ],
                                           feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                      img_row2: noncrush[:, 1, :, :, :, :],
                                                      img_row3: crush[:, 2, :, :, :, :],
                                                      img_row4: noncrush[:, 3, :, :, :, :],
                                                      img_row5: crush[:, 4, :, :, :, :],
                                                      img_row6: noncrush[:, 5, :, :, :, :],
                                                      img_row7: crush[:, 6, :, :, :, :],
                                                      img_row8: noncrush[:, 7, :, :, :, :],
                                                      mri_ph: mri[:, 0, :, :, :, :],
                                                      label1: perf[:, 0, :, :, :, :],
                                                      label2: perf[:, 1, :, :, :, :],
                                                      label3: perf[:, 2, :, :, :, :],
                                                      label4: perf[:, 3, :, :, :, :],
                                                      label5: perf[:, 4, :, :, :, :],
                                                      label6: perf[:, 5, :, :, :, :],
                                                      label7: perf[:, 6, :, :, :, :],
                                                      label8: angio[:, 0, :, :, :, :],
                                                      label9: angio[:, 1, :, :, :, :],
                                                      label10: angio[:, 2, :, :, :, :],
                                                      label11: angio[:, 3, :, :, :, :],
                                                      label12: angio[:, 4, :, :, :, :],
                                                      label13: angio[:, 5, :, :, :, :],
                                                      label14: angio[:, 6, :, :, :, :],
                                                      is_training: False,
                                                      input_dim: self.patch_window,
                                                      all_loss: loss_train1,


                                                      })
                    train_writer.add_summary(sum_train, point)
                    step = step + 1

                    point = point + self.display_train_step
                    if point % 200 == 0:
                        break
                    process = psutil.Process(os.getpid())
                    print(
                        'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s' % (
                            int((point)),
                            step * self.batch_no, self.learning_rate, loss_train1,
                            str(process.memory_percent())))
                    point = int((point))  # (self.no_sample_per_each_itr/self.batch_no)*itr1+step

                    if point % 100 == 0:
                        '''saveing model inter epoch'''
                        chckpnt_path = os.path.join(self.chckpnt_dir,
                                                    ('unet_inter_epoch%d_point%d.ckpt' % (epoch, point)))
                        saver.save(sess, chckpnt_path, global_step=point)



            endTime = time.time()
            # ==============end of epoch:
            '''saveing model after each epoch'''
            chckpnt_path = os.path.join(self.chckpnt_dir, 'unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=epoch)

            print("End of epoch----> %d, elapsed time: %d" % (epoch, endTime - startTime))
def test_all_nets():
    """Evaluate the trained multi-stage perfusion/angiography network.

    Builds the inference graph of ``_multi_stage_densenet``, restores the
    checkpoint referenced by ``<parent_path>unet_checkpoints/`` and, for the
    first (at most) two images of the selected split, writes every predicted
    output channel and the corresponding ground-truth volumes as ``.mha``
    files under ``<parent_path>results/<case>/``.

    Output channels 0-6 are saved as ``perf_<k>.mha`` and scored against
    ``perf[k]``; channels 7-13 are saved as ``angi_<k>.mha`` and scored
    against ``angio[k]``.  The per-channel ``anly.analysis`` results are
    exported with ``save_in_xlsx``.

    The evaluated split is the validation set when the module-level
    ``test_vali`` equals 1, otherwise the test set.  The function also reads
    the module-level names ``patch_window``, ``label_patchs_size``,
    ``in_dim``, ``final_layer``, ``max_perf`` and ``max_angio``.

    Raises:
        ValueError: if the selected split contains no images.
        FileNotFoundError: if no checkpoint is found in the checkpoint dir.
    """
    data = 2

    server = 'DL'
    if server == 'DL':
        parent_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/multi_stage/experiment-2/'
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation00_low/'
    else:
        parent_path = '/exports/lkeb-hpc/syousefi/Code/'
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation00_low/'

    img_name = ''
    label_name = ''

    _rd = _read_data(data=data,
                     reverse=False,
                     img_name=img_name,
                     label_name=label_name,
                     dataset_path=data_path)
    # Read the image paths of the train, validation and test splits.
    train_data, validation_data, test_data = _rd.read_data_path()

    chckpnt_dir = parent_path + 'unet_checkpoints/'
    result_path = parent_path + 'results/'
    test_set = validation_data if test_vali == 1 else test_data

    # Only the first two images are evaluated (a loop over len(test_set) was
    # commented out by the author). Clamp so a smaller split cannot raise
    # IndexError, and fail early rather than at the xlsx export.
    n_images = min(2, len(test_set))
    if n_images == 0:
        raise ValueError('The selected test set contains no images.')

    n_perf = 7               # perfusion volumes (also the number of angio volumes)
    n_channels = 2 * n_perf  # network output channels: 7 perf + 7 angio

    # Spatial dims are left open so whole volumes can be fed at test time.
    volume_shape = [None, None, None, None, 1]
    img_rows = [tf.placeholder(tf.float32, shape=volume_shape)
                for _ in range(8)]
    mri_ph = tf.placeholder(tf.float32, shape=volume_shape, name='mri')
    # labels[0:7] are fed with perf[0:7], labels[7:14] with angio[0:7].
    labels = [tf.placeholder(tf.float32, shape=volume_shape)
              for _ in range(n_channels)]
    all_loss = tf.placeholder(tf.float32, name='loss')
    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    multi_stage_densenet = _multi_stage_densenet()
    y, loss_upsampling11, loss_upsampling22 = multi_stage_densenet.multi_stage_densenet(
        img_row1=img_rows[0],
        img_row2=img_rows[1],
        img_row3=img_rows[2],
        img_row4=img_rows[3],
        img_row5=img_rows[4],
        img_row6=img_rows[5],
        img_row7=img_rows[6],
        img_row8=img_rows[7],
        input_dim=input_dim,
        mri=mri_ph,
        is_training=is_training)

    loss_instance = _loss_func()
    # Split the multi-channel outputs into per-channel tensors, keeping a
    # trailing singleton axis so each one matches its label placeholder.
    logits = [y[:, :, :, :, c, np.newaxis] for c in range(n_channels)]
    stage1 = [loss_upsampling11[:, :, :, :, c, np.newaxis] for c in range(n_channels)]
    stage2 = [loss_upsampling22[:, :, :, :, c, np.newaxis] for c in range(n_channels)]

    # The loss graph is still built, but only `y` is evaluated below;
    # no training ops are needed for inference.
    with tf.name_scope('Loss'):
        loss_dic = loss_instance.loss_selector(
            'Multistage_ssim_perf_angio_loss',
            labels=labels,
            logits=logits,
            stage1=stage1,
            stage2=stage2)
        tf.reduce_mean(loss_dic["loss"], name="cost")

    def _as_batch(volume):
        # Add a leading batch axis and a trailing channel axis.
        return np.expand_dims(np.expand_dims(volume, 0), -1)

    def _write_mha(array, path, spacing, direction, origin):
        # Save `array` with the geometry returned by read_image_for_test.
        image = sitk.GetImageFromArray(array)
        image.SetDirection(direction=direction)
        image.SetOrigin(origin=origin)
        image.SetSpacing(spacing=spacing)
        sitk.WriteImage(image, path)

    # dics[k] collects the metrics of output channel k over all images.
    dics = [[] for _ in range(n_channels)]
    headers = None
    # NOTE(review): `loss` is never accumulated, so the final
    # "Total loss" print always reports 0.
    loss = 0

    with tf.Session() as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in ' + chckpnt_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        os.makedirs(result_path, exist_ok=True)
        copyfile('./test_synthesize_multistage_perf_angio.py',
                 result_path + 'test_synthesize_multistage_perf_angio.py')

        _image_class = image_class(train_data,
                                   bunch_of_images_no=1,
                                   is_training=1,
                                   patch_window=patch_window,
                                   sample_no_per_bunch=1,
                                   label_patch_size=label_patchs_size,
                                   validation_total_sample=0)

        for img_indx in range(n_images):
            (crush, noncrush, perf, angio, mri, _, spacing, direction,
             origin) = _image_class.read_image_for_test(
                test_set=test_set,
                img_indx=img_indx,
                input_size=in_dim,
                final_layer=final_layer)
            t = time.time()

            feed_dict = {
                mri_ph: _as_batch(mri),
                is_training: False,
                input_dim: patch_window,
                all_loss: -1.,
            }
            # Even-indexed inputs come from `crush`, odd-indexed ones from
            # `noncrush`; input k always takes volume k of its source.
            for k, placeholder in enumerate(img_rows):
                source = crush if k % 2 == 0 else noncrush
                feed_dict[placeholder] = _as_batch(source[k])
            for k in range(n_perf):
                feed_dict[labels[k]] = _as_batch(perf[k])
                feed_dict[labels[n_perf + k]] = _as_batch(angio[k])

            [out] = sess.run([y], feed_dict=feed_dict)
            elapsed = time.time() - t
            print(elapsed)

            res_dir = test_set[img_indx][0][0].split('/')[-2]
            case_dir = result_path + res_dir + '/'
            gt_dir = case_dir + 'GT/'
            # exist_ok lets the script be re-run without deleting old results.
            os.makedirs(gt_dir, exist_ok=True)

            for i in range(np.shape(out)[-1]):
                nm = 'perf' if i < n_perf else 'angi'
                path = case_dir + nm + '_' + str(i % n_perf) + '.mha'
                _write_mha(out[0, :, :, :, i], path, spacing, direction, origin)
                print(path + ' done!')

            for k in range(n_perf):
                _write_mha(angio[k], gt_dir + 'angio_' + str(k) + '.mha',
                           spacing, direction, origin)
                _write_mha(perf[k], gt_dir + 'perf_' + str(k) + '.mha',
                           spacing, direction, origin)

            # Score each predicted channel against its *own* ground truth
            # (channel k vs perf[k], channel 7 + k vs angio[k]).
            for k in range(n_perf):
                dics[k].append(
                    anly.analysis(out[0, :, :, :, k], perf[k], 0, max_perf))
            for k in range(n_perf):
                dics[n_perf + k].append(
                    anly.analysis(out[0, :, :, :, n_perf + k], angio[k], 0, max_angio))
            if headers is None:
                headers = dics[0][0].keys()

    save_in_xlsx(parent_path, headers, dics=dics)
    print('Total loss: ', loss / len(test_set))
コード例 #5
0
def test_all_nets():
    """Run the restored MRI-conditioned DenseNet synthesizer over one data split.

    The split is chosen by the module-level ``test_vali`` flag (``1`` ->
    validation data, anything else -> test data). For every case the network
    output is written channel by channel to
    ``<parent_path>results/<case>/{perf,angi}_<k>.mha`` (channels 0-6 are
    named ``perf``, channels 7 and above ``angi``, with ``k = channel % 7``),
    and the seven perfusion ground-truth volumes are written to
    ``<parent_path>results/<case>/GT/perf_<k>.mha``. The per-case cost is
    printed, followed by its mean over the split.

    Besides the names imported by this module, the function reads the
    module-level settings ``test_vali``, ``patch_window``,
    ``label_patchs_size``, ``in_dim`` and ``final_layer``.

    Raises:
        FileNotFoundError: if no checkpoint state exists in
            ``<parent_path>unet_checkpoints/``.
    """
    data = 2

    server = 'DL'
    if server == 'DL':
        parent_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/MRI_in/experiment-21/'
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation2_low/'
    else:
        parent_path = '/exports/lkeb-hpc/syousefi/Code/'
        data_path = '/exports/lkeb-hpc/syousefi/Synth_Data/BrainWeb_permutation2_low/'

    _rd = _read_data(data=data,
                     reverse=False,
                     img_name='',
                     label_name='',
                     dataset_path=data_path)
    # Read the image paths of the train, validation and test splits.
    train_data, validation_data, test_data = _rd.read_data_path()

    chckpnt_dir = parent_path + 'unet_checkpoints/'
    result_path = parent_path + 'results/'

    test_set = validation_data if test_vali == 1 else test_data

    def new_volume_placeholder():
        # Batch, three dynamic volume axes and one channel (matches the
        # expand_dims(..., 0) / expand_dims(..., -1) used when feeding).
        return tf.placeholder(tf.float32, shape=[None, None, None, None, 1])

    def as_batch(volume):
        # Add a leading batch axis and a trailing channel axis.
        return np.expand_dims(np.expand_dims(volume, 0), -1)

    def write_volume(volume, path, spacing, direction, origin):
        """Save *volume* at *path* with the geometry of the current case."""
        image = sitk.GetImageFromArray(volume)
        image.SetDirection(direction=direction)
        image.SetOrigin(origin=origin)
        image.SetSpacing(spacing=spacing)
        sitk.WriteImage(image, path)

    # Placeholders are created in the same order as before.
    img_rows = [new_volume_placeholder() for _ in range(8)]
    mri_ph = tf.placeholder(tf.float32,
                            shape=[None, None, None, None, 1],
                            name='mri')
    labels = [new_volume_placeholder() for _ in range(7)]

    all_loss = tf.placeholder(tf.float32, name='loss')
    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    densenet = _densenet()
    y, _ = densenet.densenet(img_row1=img_rows[0],
                             img_row2=img_rows[1],
                             img_row3=img_rows[2],
                             img_row4=img_rows[3],
                             img_row5=img_rows[4],
                             img_row6=img_rows[5],
                             img_row7=img_rows[6],
                             img_row8=img_rows[7],
                             input_dim=input_dim,
                             mri=mri_ph,
                             is_training=is_training)

    # Only the first seven output channels are compared with the seven
    # perfusion labels.
    loss_instance = _loss_func()
    logits = [y[:, :, :, :, k, np.newaxis] for k in range(len(labels))]
    with tf.name_scope('Loss'):
        loss_dic = loss_instance.loss_selector('SSIM_perf',
                                               labels=labels,
                                               logits=logits)
        cost = tf.reduce_mean(loss_dic["loss"], name="cost")

    # Restore the trained model; fail with a clear message if nothing to load.
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ' + chckpnt_dir)

    os.makedirs(result_path, exist_ok=True)
    loss = 0
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Keep a copy of the evaluation script next to its results.
        copyfile('./test_synthesize_net_mri.py',
                 result_path + '/test_synthesize_net_mri.py')

        _image_class = image_class(train_data,
                                   bunch_of_images_no=1,
                                   is_training=1,
                                   patch_window=patch_window,
                                   sample_no_per_bunch=1,
                                   label_patch_size=label_patchs_size,
                                   validation_total_sample=0)

        for img_indx, case in enumerate(test_set):
            crush, noncrush, perf, angio, mri, segmentation, spacing, direction, origin = _image_class.read_image_for_test(
                test_set=test_set,
                img_indx=img_indx,
                input_size=in_dim,
                final_layer=final_layer)

            # Even input slots take the crushed series, odd slots the
            # non-crushed series.
            feed_dict = {
                ph: as_batch(crush[k] if k % 2 == 0 else noncrush[k])
                for k, ph in enumerate(img_rows)
            }
            feed_dict.update(
                {ph: as_batch(perf[k]) for k, ph in enumerate(labels)})
            feed_dict[mri_ph] = as_batch(mri)
            feed_dict[is_training] = False
            feed_dict[input_dim] = patch_window
            feed_dict[all_loss] = -1.

            loss_train1, out = sess.run([cost, y], feed_dict=feed_dict)

            case_dir = result_path + case[0][0].split('/')[-2] + '/'
            os.makedirs(case_dir + 'GT', exist_ok=True)

            for channel in range(np.shape(out)[-1]):
                nm = 'perf' if channel < 7 else 'angi'
                path = case_dir + nm + '_' + str(channel % 7) + '.mha'
                write_volume(out[0, :, :, :, channel], path,
                             spacing, direction, origin)
                print(path + ' done!')

            for k in range(7):
                write_volume(perf[k], case_dir + 'GT/perf_' + str(k) + '.mha',
                             spacing, direction, origin)

            loss += loss_train1
            print('Loss_train: ', loss_train1)

    if len(test_set) == 0:
        print('No images in the selected split; nothing to evaluate.')
    else:
        print('Total loss: ', loss / len(test_set))
コード例 #6
0
def test_all_nets(newdataset):
    """Run the restored forked-DenseNet synthesizer over one data split.

    The split and output folder are chosen by the module-level ``test_vali``
    flag:

    * ``1`` -> validation data, written to ``<parent_path>results/``
    * ``2`` -> training data, written to ``<parent_path>results_tr/``
    * otherwise -> test data, written to ``<parent_path>results_vali/``

    For every case the network output is written channel by channel to
    ``<result_path><case>/{perf,angi}_<k>.mha`` (channels 0-6 are named
    ``perf``, channels 7 and above ``angi``, with ``k = channel % 7``). The
    seven perfusion and seven angiography ground-truth volumes are written
    under ``<result_path><case>/GT/``. The per-case cost is printed, followed
    by its mean over the split.

    Besides the names imported by this module, the function reads the
    module-level settings ``test_vali``, ``patch_window``,
    ``label_patchs_size``, ``in_dim`` and ``final_layer``.

    Args:
        newdataset: if truthy, read ``BrainWeb_permutation00_low`` instead of
            ``BrainWeb_permutation2_low``.

    Raises:
        FileNotFoundError: if no checkpoint state exists in
            ``<parent_path>unet_checkpoints/``.
    """
    data = 2

    server = 'shark'
    dataset_dir = ('BrainWeb_permutation00_low/' if newdataset
                   else 'BrainWeb_permutation2_low/')
    if server == 'DL':
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/' + dataset_dir
    else:
        data_path = '/exports/lkeb-hpc/syousefi/Synth_Data/' + dataset_dir

    _rd = _read_data(data=data,
                     img_name='',
                     label_name='',
                     dataset_path=data_path)
    # Read the image paths of the train, validation and test splits.
    train_data, validation_data, test_data = _rd.read_data_path()

    # Log directory of the experiment being evaluated (same on every server).
    parent_path = '/exports/lkeb-hpc/syousefi/Code/ASL_LOG/debug_Log/synth-12/'
    chckpnt_dir = parent_path + 'unet_checkpoints/'

    if test_vali == 1:
        test_set, result_path = validation_data, parent_path + 'results/'
    elif test_vali == 2:
        test_set, result_path = train_data, parent_path + 'results_tr/'
    else:
        # NOTE(review): the test split goes to 'results_vali/' (kept from the
        # original); confirm this folder name is intended.
        test_set, result_path = test_data, parent_path + 'results_vali/'

    def new_volume_placeholder():
        # Batch, three dynamic volume axes and one channel (matches the
        # expand_dims(..., -1) / expand_dims(..., 0) used when feeding).
        return tf.placeholder(tf.float32, shape=[None, None, None, None, 1])

    def as_batch(volume):
        # Add a trailing channel axis and a leading batch axis.
        return np.expand_dims(np.expand_dims(volume, axis=-1), axis=0)

    def write_volume(volume, path, spacing, direction, origin):
        """Save *volume* at *path* with the geometry of the current case."""
        image = sitk.GetImageFromArray(volume)
        image.SetDirection(direction=direction)
        image.SetOrigin(origin=origin)
        image.SetSpacing(spacing=spacing)
        sitk.WriteImage(image, path)

    # Placeholders are created in the same order as before.
    img_rows = [new_volume_placeholder() for _ in range(8)]
    labels = [new_volume_placeholder() for _ in range(14)]
    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    forked_densenet = _forked_densenet()
    # densenet() returns the output tensor followed by eight input tensors.
    # Those returned tensors replace the placeholders above and are what gets
    # fed below. NOTE(review): confirm they are the intended feed points.
    y, *img_rows = forked_densenet.densenet(img_row1=img_rows[0],
                                            img_row2=img_rows[1],
                                            img_row3=img_rows[2],
                                            img_row4=img_rows[3],
                                            img_row5=img_rows[4],
                                            img_row6=img_rows[5],
                                            img_row7=img_rows[6],
                                            img_row8=img_rows[7],
                                            input_dim=input_dim,
                                            is_training=is_training)

    loss_instance = _loss_func()
    with tf.name_scope('averaged_mean_squared_error'):
        # label<k> is compared against output channel k-1, for k = 1..14.
        loss_kwargs = {'label%d' % (k + 1): label
                       for k, label in enumerate(labels)}
        loss_kwargs.update({'logit%d' % (k + 1): y[:, :, :, :, k, np.newaxis]
                            for k in range(len(labels))})
        # Only the combined loss (first returned term) is needed here.
        _loss, *_ = loss_instance.averaged_SSIM_huber(**loss_kwargs)
        cost = tf.reduce_mean(_loss, name="cost")

    # Restore the trained model; fail with a clear message if nothing to load.
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ' + chckpnt_dir)

    os.makedirs(result_path, exist_ok=True)
    loss = 0
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Keep a copy of the evaluation script next to its results.
        copyfile('./test_synthesize_net2.py',
                 result_path + '/test_synthesize_net2.py')

        _image_class = image_class(train_data,
                                   bunch_of_images_no=1,
                                   is_training=1,
                                   patch_window=patch_window,
                                   sample_no_per_bunch=1,
                                   label_patch_size=label_patchs_size,
                                   validation_total_sample=0)

        for img_indx, case in enumerate(test_set):
            crush, noncrush, perf, angio, spacing, direction, origin = _image_class.read_image_for_test(
                test_set=test_set,
                img_indx=img_indx,
                input_size=in_dim,
                final_layer=final_layer)

            # Even input slots take the crushed series, odd slots the
            # non-crushed series.
            feed_dict = {
                ph: as_batch(crush[k] if k % 2 == 0 else noncrush[k])
                for k, ph in enumerate(img_rows)
            }
            # Labels 1-7 receive perf[0..6], labels 8-14 receive angio[0..6].
            targets = [perf[k] for k in range(7)] + [angio[k] for k in range(7)]
            feed_dict.update(
                {ph: as_batch(target) for ph, target in zip(labels, targets)})
            feed_dict[is_training] = False
            feed_dict[input_dim] = patch_window

            loss_train1, out = sess.run([cost, y], feed_dict=feed_dict)

            case_dir = result_path + case[0][0].split('/')[-2] + '/'
            os.makedirs(case_dir + 'GT', exist_ok=True)

            for channel in range(np.shape(out)[-1]):
                nm = 'perf' if channel < 7 else 'angi'
                path = case_dir + nm + '_' + str(channel % 7) + '.mha'
                write_volume(out[0, :, :, :, channel], path,
                             spacing, direction, origin)
                print(path + ' done!')

            for k in range(7):
                write_volume(angio[k], case_dir + 'GT/angio_' + str(k) + '.mha',
                             spacing, direction, origin)
                write_volume(perf[k], case_dir + 'GT/perf_' + str(k) + '.mha',
                             spacing, direction, origin)

            loss += loss_train1
            print('Loss_train: ', loss_train1)

    if len(test_set) == 0:
        print('No images in the selected split; nothing to evaluate.')
    else:
        print('Total loss: ', loss / len(test_set))
コード例 #7
0
def test_all_nets(num, test_set, max_images=1):
    """Restore checkpoint ``num`` and evaluate the network on ``test_set``.

    Builds the multi-stage DenseNet inference graph (8 ASL input volumes
    plus the MRI volume; 14 output channels, the first 7 paired with
    ``perf[0..6]`` and the last 7 with ``angio[0..6]``). It then restores
    ``unet_inter_epoch0_point<num>.ckpt-<num>`` from the directory of the
    latest checkpoint under ``parent_path + 'unet_checkpoints/'``.

    For each of the first ``max_images`` entries of ``test_set`` it prints
    the loss and the ``analysis.analysis`` result of every output channel
    against its ground truth. Finally it prints the mean loss over the
    evaluated images.

    Reads the module-level globals ``parent_path``, ``train_data``,
    ``patch_window``, ``label_patchs_size``, ``in_dim``, ``final_layer``,
    ``max_perf`` and ``analysis``.

    Args:
        num: checkpoint number used to build the checkpoint file name.
        test_set: image-path entries accepted by
            ``image_class.read_image_for_test``.
        max_images: how many leading entries of ``test_set`` to evaluate.
            Defaults to 1, the previously hard-coded behaviour.

    Raises:
        IOError: if no checkpoint state exists in the checkpoint directory.
    """
    # label1..7 are fed from perf[0..6] and label8..14 from angio[0..6], so
    # output channel c < n_perf is perfusion, c >= n_perf is angiography.
    n_perf = 7
    n_channels = 2 * n_perf

    chckpnt_dir = parent_path + 'unet_checkpoints/'

    def _placeholder(**kwargs):
        """Create a float32 placeholder for a single-channel 5-D volume batch."""
        return tf.placeholder(tf.float32,
                              shape=[None, None, None, None, 1],
                              **kwargs)

    # Placeholders are created in the same order as before:
    # inputs, MRI, labels, then the scalar placeholders.
    img_rows = [_placeholder() for _ in range(8)]
    mri_ph = _placeholder(name='mri')
    labels = [_placeholder() for _ in range(n_channels)]

    all_loss = tf.placeholder(tf.float32, name='loss')
    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    multi_stage_densenet = _multi_stage_densenet()
    img_row_kwargs = {
        'img_row%d' % (k + 1): ph for k, ph in enumerate(img_rows)
    }
    y, loss_upsampling11, loss_upsampling22 = multi_stage_densenet.multi_stage_densenet(
        input_dim=input_dim,
        mri=mri_ph,
        is_training=is_training,
        **img_row_kwargs)

    loss_instance = _loss_func()

    # One single-channel tensor per output channel, as the loss expects.
    logits = [y[:, :, :, :, c, np.newaxis] for c in range(n_channels)]
    stage1 = [
        loss_upsampling11[:, :, :, :, c, np.newaxis] for c in range(n_channels)
    ]
    stage2 = [
        loss_upsampling22[:, :, :, :, c, np.newaxis] for c in range(n_channels)
    ]

    with tf.name_scope('Loss'):
        loss_dic = loss_instance.loss_selector(
            'Multistage_ssim_perf_angio_loss',
            labels=labels,
            logits=logits,
            stage1=stage1,
            stage2=stage2)
        cost = tf.reduce_mean(loss_dic["loss"], name="cost")

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('No checkpoint state found in ' + chckpnt_dir)

    # Same directory as the latest checkpoint, but a specific numbered file.
    model_path = (ckpt.model_checkpoint_path.rsplit('/', 1)[0] +
                  '/unet_inter_epoch0_point' + str(num) + '.ckpt-' + str(num))

    def _as_batch(volume):
        """Add a leading batch axis and a trailing channel axis."""
        return np.expand_dims(np.expand_dims(volume, 0), -1)

    loss = 0.
    n_evaluated = 0
    # The session is closed on exit, even if reading or inference fails.
    with tf.Session() as sess:
        saver.restore(sess, model_path)

        _image_class = image_class(train_data,
                                   bunch_of_images_no=1,
                                   is_training=1,
                                   patch_window=patch_window,
                                   sample_no_per_bunch=1,
                                   label_patch_size=label_patchs_size,
                                   validation_total_sample=0)

        for img_indx in range(min(max_images, len(test_set))):
            crush, noncrush, perf, angio, mri, _, _, _, _ = _image_class.read_image_for_test(
                test_set=test_set,
                img_indx=img_indx,
                input_size=in_dim,
                final_layer=final_layer)

            # img_row inputs alternate crush (even index) / noncrush (odd index).
            feed = {
                ph: _as_batch(crush[k] if k % 2 == 0 else noncrush[k])
                for k, ph in enumerate(img_rows)
            }
            feed[mri_ph] = _as_batch(mri)
            for k in range(n_perf):
                feed[labels[k]] = _as_batch(perf[k])
                feed[labels[n_perf + k]] = _as_batch(angio[k])
            feed[is_training] = False
            feed[input_dim] = patch_window
            feed[all_loss] = -1.

            out, loss_value = sess.run([y, cost], feed_dict=feed)
            loss += loss_value
            n_evaluated += 1
            print('Loss_test: ', loss_value)

            for i in range(np.shape(out)[-1]):
                image = out[0, :, :, :, i]
                # Channels beyond the perfusion ones are angiography; indexing
                # perf[i] for them would run past the 7 perfusion volumes.
                gt_volume = perf[i] if i < n_perf else angio[i - n_perf]
                gt = sitk.GetArrayFromImage(sitk.GetImageFromArray(gt_volume))
                # NOTE(review): max_perf is also used for the angiography
                # channels; confirm whether an angio-specific maximum exists.
                dic = analysis.analysis(result=image,
                                        gt=gt,
                                        min=0,
                                        max=max_perf)
                print(dic)

    # Average over the images actually evaluated (guarding an empty set).
    print('Total loss: ', loss / max(n_evaluated, 1))
コード例 #8
0
def test_all_nets():
    """Time inference of the restored multi-stage network on every image.

    Reads the dataset split from ``data_path`` and restores the latest
    checkpoint in ``parent_path + 'unet_checkpoints/'``. The image set is
    the validation set when the module-level ``test_vali == 1``, otherwise
    the test set.

    The network is run once per image of that set. For each run it prints
    the wall-clock time of ``sess.run``, then prints the mean and standard
    deviation over all images. No loss is computed.

    Reads the module-level globals ``test_vali``, ``patch_window``,
    ``label_patchs_size``, ``in_dim`` and ``final_layer``.

    Raises:
        IOError: if no checkpoint state exists in the checkpoint directory.
    """
    data = 2

    server = 'DL'
    # The dataset location is the same for both servers.
    data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation00_low/'
    if server == 'DL':
        parent_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/multi_stage/experiment-2/'
    else:
        parent_path = '/exports/lkeb-hpc/syousefi/Code/'

    _rd = _read_data(data=data,
                     reverse=False,
                     img_name='',
                     label_name='',
                     dataset_path=data_path)
    # Read the image paths for the train, validation and test splits.
    train_data, validation_data, test_data = _rd.read_data_path()

    chckpnt_dir = parent_path + 'unet_checkpoints/'
    test_set = validation_data if test_vali == 1 else test_data

    # label1..7 take perf[0..6] and label8..14 take angio[0..6].
    n_perf = 7
    n_channels = 2 * n_perf

    def _placeholder(**kwargs):
        """Create a float32 placeholder for a single-channel 5-D volume batch."""
        return tf.placeholder(tf.float32,
                              shape=[None, None, None, None, 1],
                              **kwargs)

    # Same creation order as before: inputs, MRI, labels, scalars.
    img_rows = [_placeholder() for _ in range(8)]
    mri_ph = _placeholder(name='mri')
    labels = [_placeholder() for _ in range(n_channels)]

    all_loss = tf.placeholder(tf.float32, name='loss')
    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    multi_stage_densenet = _multi_stage_densenet()
    img_row_kwargs = {
        'img_row%d' % (k + 1): ph for k, ph in enumerate(img_rows)
    }
    y, loss_upsampling11, loss_upsampling22 = multi_stage_densenet.multi_stage_densenet(
        input_dim=input_dim,
        mri=mri_ph,
        is_training=is_training,
        **img_row_kwargs)

    loss_instance = _loss_func()
    logits = [y[:, :, :, :, c, np.newaxis] for c in range(n_channels)]
    stage1 = [
        loss_upsampling11[:, :, :, :, c, np.newaxis] for c in range(n_channels)
    ]
    stage2 = [
        loss_upsampling22[:, :, :, :, c, np.newaxis] for c in range(n_channels)
    ]

    # The loss sub-graph is not evaluated here. It is still built so that
    # the set of variables collected by the Saver below is unchanged from
    # the original graph construction.
    with tf.name_scope('Loss'):
        loss_dic = loss_instance.loss_selector(
            'Multistage_ssim_perf_angio_loss',
            labels=labels,
            logits=logits,
            stage1=stage1,
            stage2=stage2)
        tf.reduce_mean(loss_dic["loss"], name="cost")

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('No checkpoint state found in ' + chckpnt_dir)

    def _as_batch(volume):
        """Add a leading batch axis and a trailing channel axis."""
        return np.expand_dims(np.expand_dims(volume, 0), -1)

    elapsed_times = []
    # The session is closed on exit, even if reading or inference fails.
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)

        _image_class = image_class(train_data,
                                   bunch_of_images_no=1,
                                   is_training=1,
                                   patch_window=patch_window,
                                   sample_no_per_bunch=1,
                                   label_patch_size=label_patchs_size,
                                   validation_total_sample=0)

        for img_indx in range(len(test_set)):
            crush, noncrush, perf, angio, mri, _, _, _, _ = _image_class.read_image_for_test(
                test_set=test_set,
                img_indx=img_indx,
                input_size=in_dim,
                final_layer=final_layer)

            # img_row inputs alternate crush (even index) / noncrush (odd index).
            feed = {
                ph: _as_batch(crush[k] if k % 2 == 0 else noncrush[k])
                for k, ph in enumerate(img_rows)
            }
            feed[mri_ph] = _as_batch(mri)
            for k in range(n_perf):
                feed[labels[k]] = _as_batch(perf[k])
                feed[labels[n_perf + k]] = _as_batch(angio[k])
            feed[is_training] = False
            feed[input_dim] = patch_window
            feed[all_loss] = -1.

            # Only sess.run is timed. perf_counter is monotonic, unlike time.time().
            # NOTE(review): the first run may include one-time setup cost;
            # consider discarding it when reporting the mean.
            start = time.perf_counter()
            sess.run(y, feed_dict=feed)
            elapsed = time.perf_counter() - start
            elapsed_times.append(elapsed)
            print(elapsed)

    if not elapsed_times:
        print('No images to time.')
        return

    print('MEAN:')
    print(np.mean(elapsed_times))
    print('STD:')
    print(np.std(elapsed_times))
コード例 #9
0
def test_all_nets():
    data = 2

    Server = 'DL'
    if Server == 'DL':
        parent_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/Log_perceptual/regularization/perceptual-0/'
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation00_low/'
    else:
        parent_path = '/exports/lkeb-hpc/syousefi/Code/'
        data_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/BrainWeb_permutation00_low/'

    img_name = ''
    label_name = ''

    _rd = _read_data(data=data,
                     reverse=False,
                     img_name=img_name,
                     label_name=label_name,
                     dataset_path=data_path)
    '''read path of the images for train, test, and validation'''
    train_data, validation_data, test_data = _rd.read_data_path()

    chckpnt_dir = parent_path + 'unet_checkpoints/'
    result_path = parent_path + 'results/'
    batch_no = 1
    batch_no_validation = batch_no
    # label_patchs_size = 87#39  # 63
    # patch_window = 103#53  # 77#89
    if test_vali == 1:
        test_set = validation_data
    else:
        test_set = test_data
    # ===================================================================================
    img_row1 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row1')
    img_row2 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row2')
    img_row3 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row3')
    img_row4 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row4')
    img_row5 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row5')
    img_row6 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row6')
    img_row7 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row7')
    img_row8 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='img_row8')

    mri_ph = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1],
        name='mri')

    label1 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label1')
    label2 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label2')
    label3 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label3')
    label4 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label4')
    label5 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label5')
    label6 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label6')
    label7 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label7')
    label8 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label8')
    label9 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ],
                            name='label9')
    label10 = tf.placeholder(tf.float32,
                             shape=[
                                 batch_no, label_patchs_size,
                                 label_patchs_size, label_patchs_size, 1
                             ],
                             name='label10')
    label11 = tf.placeholder(tf.float32,
                             shape=[
                                 batch_no, label_patchs_size,
                                 label_patchs_size, label_patchs_size, 1
                             ],
                             name='label11')
    label12 = tf.placeholder(tf.float32,
                             shape=[
                                 batch_no, label_patchs_size,
                                 label_patchs_size, label_patchs_size, 1
                             ],
                             name='label12')
    label13 = tf.placeholder(tf.float32,
                             shape=[
                                 batch_no, label_patchs_size,
                                 label_patchs_size, label_patchs_size, 1
                             ],
                             name='label13')
    label14 = tf.placeholder(tf.float32,
                             shape=[
                                 batch_no, label_patchs_size,
                                 label_patchs_size, label_patchs_size, 1
                             ],
                             name='label14')

    is_training = tf.placeholder(tf.bool, name='is_training')
    input_dim = tf.placeholder(tf.int32, name='input_dim')

    perf_vgg_loss_tens = tf.placeholder(tf.float32, name='VGG_perf')
    angio_vgg_loss_tens = tf.placeholder(tf.float32, name='VGG_angio')
    perf_vgg_tens0 = tf.placeholder(tf.float32, name='vgg_perf0')
    perf_vgg_tens1 = tf.placeholder(tf.float32, name='vgg_perf1')
    perf_vgg_tens2 = tf.placeholder(tf.float32, name='vgg_perf2')
    perf_vgg_tens3 = tf.placeholder(tf.float32, name='vgg_perf3')
    perf_vgg_tens4 = tf.placeholder(tf.float32, name='vgg_perf4')
    perf_vgg_tens5 = tf.placeholder(tf.float32, name='vgg_perf5')
    perf_vgg_tens6 = tf.placeholder(tf.float32, name='vgg_perf6')

    angio_vgg_tens0 = tf.placeholder(tf.float32, name='vgg_angio0')
    angio_vgg_tens1 = tf.placeholder(tf.float32, name='vgg_angio1')
    angio_vgg_tens2 = tf.placeholder(tf.float32, name='vgg_angio2')
    angio_vgg_tens3 = tf.placeholder(tf.float32, name='vgg_angio3')
    angio_vgg_tens4 = tf.placeholder(tf.float32, name='vgg_angio4')
    angio_vgg_tens5 = tf.placeholder(tf.float32, name='vgg_angio5')
    angio_vgg_tens6 = tf.placeholder(tf.float32, name='vgg_angio6')
    perf_huber_loss_tens = tf.placeholder(tf.float32, name='huber_perf')
    angio_huber_loss_tens = tf.placeholder(tf.float32, name='huber_angio')

    perf_huber_tens0 = tf.placeholder(tf.float32, name='huber_perf0')
    perf_huber_tens1 = tf.placeholder(tf.float32, name='huber_perf1')
    perf_huber_tens2 = tf.placeholder(tf.float32, name='huber_perf2')
    perf_huber_tens3 = tf.placeholder(tf.float32, name='huber_perf3')
    perf_huber_tens4 = tf.placeholder(tf.float32, name='huber_perf4')
    perf_huber_tens5 = tf.placeholder(tf.float32, name='huber_perf5')
    perf_huber_tens6 = tf.placeholder(tf.float32, name='huber_perf6')

    angio_huber_tens0 = tf.placeholder(tf.float32, name='huber_angio0')
    angio_huber_tens1 = tf.placeholder(tf.float32, name='huber_angio1')
    angio_huber_tens2 = tf.placeholder(tf.float32, name='huber_angio2')
    angio_huber_tens3 = tf.placeholder(tf.float32, name='huber_angio3')
    angio_huber_tens4 = tf.placeholder(tf.float32, name='huber_angio4')
    angio_huber_tens5 = tf.placeholder(tf.float32, name='huber_angio5')
    angio_huber_tens6 = tf.placeholder(tf.float32, name='huber_angio6')
    # ===================================================================================
    densenet = _densenet()

    [y, _] = densenet.densenet(img_row1=img_row1,
                               img_row2=img_row2,
                               img_row3=img_row3,
                               img_row4=img_row4,
                               img_row5=img_row5,
                               img_row6=img_row6,
                               img_row7=img_row7,
                               img_row8=img_row8,
                               input_dim=input_dim,
                               is_training=is_training)
    vgg = vgg_feature_maker(test=1)
    feature_type = 'huber'
    vgg_y0 = vgg.feed_img(y[:, :, :, :, 0], feature_type=feature_type).copy()
    vgg_y1 = vgg.feed_img(y[:, :, :, :, 1], feature_type=feature_type).copy()
    vgg_y2 = vgg.feed_img(y[:, :, :, :, 2], feature_type=feature_type).copy()
    vgg_y3 = vgg.feed_img(y[:, :, :, :, 3], feature_type=feature_type).copy()
    vgg_y4 = vgg.feed_img(y[:, :, :, :, 4], feature_type=feature_type).copy()
    vgg_y5 = vgg.feed_img(y[:, :, :, :, 5], feature_type=feature_type).copy()
    vgg_y6 = vgg.feed_img(y[:, :, :, :, 6], feature_type=feature_type).copy()

    vgg_y7 = vgg.feed_img(y[:, :, :, :, 7], feature_type=feature_type)
    vgg_y8 = vgg.feed_img(y[:, :, :, :, 8], feature_type=feature_type)
    vgg_y9 = vgg.feed_img(y[:, :, :, :, 9], feature_type=feature_type)
    vgg_y10 = vgg.feed_img(y[:, :, :, :, 10], feature_type=feature_type)
    vgg_y11 = vgg.feed_img(y[:, :, :, :, 11], feature_type=feature_type)
    vgg_y12 = vgg.feed_img(y[:, :, :, :, 12], feature_type=feature_type)
    vgg_y13 = vgg.feed_img(y[:, :, :, :, 13], feature_type=feature_type)

    vgg_label0 = vgg.feed_img(label1[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label1 = vgg.feed_img(label2[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label2 = vgg.feed_img(label3[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label3 = vgg.feed_img(label4[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label4 = vgg.feed_img(label5[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label5 = vgg.feed_img(label6[:, :, :, :, 0],
                              feature_type=feature_type).copy()
    vgg_label6 = vgg.feed_img(label7[:, :, :, :, 0],
                              feature_type=feature_type).copy()

    vgg_label7 = vgg.feed_img(label8[:, :, :, :, 0], feature_type=feature_type)
    vgg_label8 = vgg.feed_img(label9[:, :, :, :, 0], feature_type=feature_type)
    vgg_label9 = vgg.feed_img(label10[:, :, :, :, 0],
                              feature_type=feature_type)
    vgg_label10 = vgg.feed_img(label11[:, :, :, :, 0],
                               feature_type=feature_type)
    vgg_label11 = vgg.feed_img(label12[:, :, :, :, 0],
                               feature_type=feature_type)
    vgg_label12 = vgg.feed_img(label13[:, :, :, :, 0],
                               feature_type=feature_type)
    vgg_label13 = vgg.feed_img(label14[:, :, :, :, 0],
                               feature_type=feature_type)

    all_loss = tf.placeholder(tf.float32, name='loss')
    # is_training = tf.placeholder(tf.bool, name='is_training')
    # input_dim = tf.placeholder(tf.int32, name='input_dim')
    # ave_huber = tf.placeholder(tf.float32, name='huber')

    labels = []
    labels.append(label1)
    labels.append(label2)
    labels.append(label3)
    labels.append(label4)
    labels.append(label5)
    labels.append(label6)
    labels.append(label7)

    labels.append(label8)
    labels.append(label9)
    labels.append(label10)
    labels.append(label11)
    labels.append(label12)
    labels.append(label13)
    labels.append(label14)

    logits = []
    logits.append(y[:, :, :, :, 0, np.newaxis])
    logits.append(y[:, :, :, :, 1, np.newaxis])
    logits.append(y[:, :, :, :, 2, np.newaxis])
    logits.append(y[:, :, :, :, 3, np.newaxis])
    logits.append(y[:, :, :, :, 4, np.newaxis])
    logits.append(y[:, :, :, :, 5, np.newaxis])
    logits.append(y[:, :, :, :, 6, np.newaxis])

    logits.append(y[:, :, :, :, 7, np.newaxis])
    logits.append(y[:, :, :, :, 8, np.newaxis])
    logits.append(y[:, :, :, :, 9, np.newaxis])
    logits.append(y[:, :, :, :, 10, np.newaxis])
    logits.append(y[:, :, :, :, 11, np.newaxis])
    logits.append(y[:, :, :, :, 12, np.newaxis])
    logits.append(y[:, :, :, :, 13, np.newaxis])

    loss_instance = _loss_func()
    vgg_in_feature = []
    vgg_in_feature.append(vgg_y0)
    vgg_in_feature.append(vgg_y1)
    vgg_in_feature.append(vgg_y2)
    vgg_in_feature.append(vgg_y3)
    vgg_in_feature.append(vgg_y4)
    vgg_in_feature.append(vgg_y5)
    vgg_in_feature.append(vgg_y6)

    vgg_in_feature.append(vgg_y7)
    vgg_in_feature.append(vgg_y8)
    vgg_in_feature.append(vgg_y9)
    vgg_in_feature.append(vgg_y10)
    vgg_in_feature.append(vgg_y11)
    vgg_in_feature.append(vgg_y12)
    vgg_in_feature.append(vgg_y13)

    vgg_label_feature = []
    vgg_label_feature.append(vgg_label0)
    vgg_label_feature.append(vgg_label1)
    vgg_label_feature.append(vgg_label2)
    vgg_label_feature.append(vgg_label3)
    vgg_label_feature.append(vgg_label4)
    vgg_label_feature.append(vgg_label5)
    vgg_label_feature.append(vgg_label6)

    vgg_label_feature.append(vgg_label7)
    vgg_label_feature.append(vgg_label8)
    vgg_label_feature.append(vgg_label9)
    vgg_label_feature.append(vgg_label10)
    vgg_label_feature.append(vgg_label11)
    vgg_label_feature.append(vgg_label12)
    vgg_label_feature.append(vgg_label13)
    with tf.name_scope('Loss'):
        loss_dic = loss_instance.loss_selector(
            'content_vgg_pairwise_loss_huber',
            labels=vgg_label_feature,
            logits=vgg_in_feature,
            vgg=vgg,
            h_labels=labels,
            h_logits=logits)
        cost = tf.reduce_mean(loss_dic["loss"], name="cost")
        # cost_angio = tf.reduce_mean(loss_dic["angio_SSIM"], name="angio_SSIM")
        # cost_perf = tf.reduce_mean(loss_dic["perf_SSIM"], name="perf_SSIM")

    # ========================================================================
    # ave_loss = tf.placeholder(tf.float32, name='loss')
    # ave_loss_perf = tf.placeholder(tf.float32, name='loss_perf')
    # ave_loss_angio = tf.placeholder(tf.float32, name='loss_angio')
    #
    # average_gradient_perf = tf.placeholder(tf.float32, name='grad_ave_perf')
    # average_gradient_angio = tf.placeholder(tf.float32, name='grad_ave_angio')
    #
    # ave_huber = tf.placeholder(tf.float32, name='huber')
    # restore the model
    sess = tf.Session()
    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)

    copyfile('./test_synthesize_ssim_perf_angio.py',
             result_path + '/test_synthesize_ssim_perf_angio.py')

    _image_class = image_class(train_data,
                               bunch_of_images_no=1,
                               is_training=1,
                               patch_window=patch_window,
                               sample_no_per_bunch=1,
                               label_patch_size=label_patchs_size,
                               validation_total_sample=0)
    learning_rate = 1E-5
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(extra_update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
        # init = tf.global_variables_initializer()
    dic_perf0 = []
    dic_perf1 = []
    dic_perf2 = []
    dic_perf3 = []
    dic_perf4 = []
    dic_perf5 = []
    dic_perf6 = []

    dic_angio0 = []
    dic_angio1 = []
    dic_angio2 = []
    dic_angio3 = []
    dic_angio4 = []
    dic_angio5 = []
    dic_angio6 = []
    loss = 0
    for img_indx in range(len(test_set)):
        crush, noncrush, perf, angio, mri, segmentation_, spacing, direction, origin = _image_class.read_image_for_test(
            test_set=test_set,
            img_indx=img_indx,
            input_size=in_dim,
            final_layer=final_layer)
        [out] = sess.run(
            [y],
            feed_dict={
                img_row1: np.expand_dims(np.expand_dims(crush[0], 0), -1),
                img_row2: np.expand_dims(np.expand_dims(noncrush[1], 0), -1),
                img_row3: np.expand_dims(np.expand_dims(crush[2], 0), -1),
                img_row4: np.expand_dims(np.expand_dims(noncrush[3], 0), -1),
                img_row5: np.expand_dims(np.expand_dims(crush[4], 0), -1),
                img_row6: np.expand_dims(np.expand_dims(noncrush[5], 0), -1),
                img_row7: np.expand_dims(np.expand_dims(crush[6], 0), -1),
                img_row8: np.expand_dims(np.expand_dims(noncrush[7], 0), -1),
                mri_ph: np.expand_dims(np.expand_dims(mri, 0), -1),
                label1: np.expand_dims(np.expand_dims(perf[0], 0), -1),
                label2: np.expand_dims(np.expand_dims(perf[1], 0), -1),
                label3: np.expand_dims(np.expand_dims(perf[2], 0), -1),
                label4: np.expand_dims(np.expand_dims(perf[3], 0), -1),
                label5: np.expand_dims(np.expand_dims(perf[4], 0), -1),
                label6: np.expand_dims(np.expand_dims(perf[5], 0), -1),
                label7: np.expand_dims(np.expand_dims(perf[6], 0), -1),
                label8: np.expand_dims(np.expand_dims(angio[0], 0), -1),
                label9: np.expand_dims(np.expand_dims(angio[1], 0), -1),
                label10: np.expand_dims(np.expand_dims(angio[2], 0), -1),
                label11: np.expand_dims(np.expand_dims(angio[3], 0), -1),
                label12: np.expand_dims(np.expand_dims(angio[4], 0), -1),
                label13: np.expand_dims(np.expand_dims(angio[5], 0), -1),
                label14: np.expand_dims(np.expand_dims(angio[6], 0), -1),
                is_training: False,
                input_dim: patch_window,
                all_loss: -1.,
                angio_vgg_loss_tens: -1,  # vgg angio
                perf_vgg_loss_tens: -1,
                perf_vgg_tens0: -1,
                perf_vgg_tens1: -1,
                perf_vgg_tens2: -1,
                perf_vgg_tens3: -1,
                perf_vgg_tens4: -1,
                perf_vgg_tens5: -1,
                perf_vgg_tens6: -1,
                angio_vgg_tens0: -1,
                angio_vgg_tens1: -1,
                angio_vgg_tens2: -1,
                angio_vgg_tens3: -1,
                angio_vgg_tens4: -1,
                angio_vgg_tens5: -1,
                angio_vgg_tens6: -1,
                perf_huber_loss_tens: -1,
                angio_huber_loss_tens: -1,
                perf_huber_tens0: -1,
                perf_huber_tens1: -1,
                perf_huber_tens2: -1,
                perf_huber_tens3: -1,
                perf_huber_tens4: -1,
                perf_huber_tens5: -1,
                perf_huber_tens6: -1,
                angio_huber_tens0: -1,
                angio_huber_tens1: -1,
                angio_huber_tens2: -1,
                angio_huber_tens3: -1,
                angio_huber_tens4: -1,
                angio_huber_tens5: -1,
                angio_huber_tens6: -1,
            })

        for i in range(np.shape(out)[-1]):
            image = out[0, :, :, :, i]
            sitk_image = sitk.GetImageFromArray(image)
            res_dir = test_set[img_indx][0][0].split('/')[-2]
            if i == 0:
                os.mkdir(parent_path + 'results/' + res_dir)
            if i < 7:
                nm = 'perf'
            else:
                nm = 'angi'
            sitk_image.SetDirection(direction=direction)
            sitk_image.SetOrigin(origin=origin)
            sitk_image.SetSpacing(spacing=spacing)
            sitk.WriteImage(
                sitk_image, parent_path + 'results/' + res_dir + '/' + nm +
                '_' + str(i % 7) + '.mha')
            print(parent_path + 'results/' + res_dir + '/' + nm + '_' +
                  str(i % 7) + '.mha done!')
        for i in range(7):
            if i == 0:
                os.mkdir(parent_path + 'results/' + res_dir + '/GT/')
            sitk_angio = sitk.GetImageFromArray(angio[i])
            sitk_angio.SetDirection(direction=direction)
            sitk_angio.SetOrigin(origin=origin)
            sitk_angio.SetSpacing(spacing=spacing)
            sitk.WriteImage(
                sitk_angio, parent_path + 'results/' + res_dir + '/GT/angio_' +
                str(i) + '.mha')

            sitk_perf = sitk.GetImageFromArray(perf[i])
            sitk_perf.SetDirection(direction=direction)
            sitk_perf.SetOrigin(origin=origin)
            sitk_perf.SetSpacing(spacing=spacing)
            sitk.WriteImage(
                sitk_perf, parent_path + 'results/' + res_dir + '/GT/perf_' +
                str(i) + '.mha')
        a = 1
        dic_perf0.append(
            anly.analysis(out[0, :, :, :, 0], perf[i], 0, max_perf))
        dic_perf1.append(
            anly.analysis(out[0, :, :, :, 1], perf[i], 0, max_perf))
        dic_perf2.append(
            anly.analysis(out[0, :, :, :, 2], perf[i], 0, max_perf))
        dic_perf3.append(
            anly.analysis(out[0, :, :, :, 3], perf[i], 0, max_perf))
        dic_perf4.append(
            anly.analysis(out[0, :, :, :, 4], perf[i], 0, max_perf))
        dic_perf5.append(
            anly.analysis(out[0, :, :, :, 5], perf[i], 0, max_perf))
        dic_perf6.append(
            anly.analysis(out[0, :, :, :, 6], perf[i], 0, max_perf))

        dic_angio0.append(
            anly.analysis(out[0, :, :, :, 7], angio[i], 0, max_angio))
        dic_angio1.append(
            anly.analysis(out[0, :, :, :, 8], angio[i], 0, max_angio))
        dic_angio2.append(
            anly.analysis(out[0, :, :, :, 9], angio[i], 0, max_angio))
        dic_angio3.append(
            anly.analysis(out[0, :, :, :, 10], angio[i], 0, max_angio))
        dic_angio4.append(
            anly.analysis(out[0, :, :, :, 11], angio[i], 0, max_angio))
        dic_angio5.append(
            anly.analysis(out[0, :, :, :, 12], angio[i], 0, max_angio))
        dic_angio6.append(
            anly.analysis(out[0, :, :, :, 13], angio[i], 0, max_angio))
        if img_indx == 0:
            headers = dic_perf0[0].keys()
        dics = [
            dic_perf0, dic_perf1, dic_perf2, dic_perf3, dic_perf4, dic_perf5,
            dic_perf6, dic_angio0, dic_angio1, dic_angio2, dic_angio3,
            dic_angio4, dic_angio5, dic_angio6
        ]
    save_in_xlsx(parent_path, headers, dics=dics)
    # plt.imshow(out[0, int(gt_cube_size / 2), :, :, 0])
    # plt.figure()
    # loss += loss_train1
    # print('Loss_train: ', loss_train1)

    print('Total loss: ', loss / len(test_set))
コード例 #10
0
# Debug scratch script: rebuild the validation image pipeline outside of the
# training class and eyeball one training image.
# NOTE(review): `data` and `fold` are not assigned anywhere in this script;
# they must already exist in the enclosing namespace before it runs.
img_name = ''
label_name = ''
torso_tag = ''
data_path = ''
log_tag = 'synth-' + str(fold)
min_range = -1000
max_range = 3000
# Log directory as a plain string (a trailing comma here would turn it into a
# one-element tuple).
Logs = 'ASL_LOG/debug_Log/'
bunch_of_images_no = 20
# Samples drawn per bunch of images; must be defined before image_class below
# reads it. 60 matches the value used in run_net.
sample_no = 60
patch_window = 77
label_patchs_size = 77
validation_samples = 100
_rd = _read_data(data=data,
                 img_name=img_name,
                 label_name=label_name,
                 dataset_path=data_path)

alpha_coeff = 1
# Read the paths of the images for train, test, and validation.
train_data, validation_data, test_data = _rd.read_data_path()
_image_class_vl = image_class(validation_data,
                              bunch_of_images_no=bunch_of_images_no,
                              is_training=0,
                              patch_window=patch_window,
                              sample_no_per_bunch=sample_no,
                              label_patch_size=label_patchs_size,
                              validation_total_sample=validation_samples)

# Visual sanity check of the training data: load it and show one slice.
# NOTE(review): `train_data` is the whole collection returned by
# read_data_path(); elsewhere in this file dataset entries are indexed as
# `entry[0][0]` to reach a file path, so this probably wants a single path.
# Confirm.
train_image = sitk.ReadImage(train_data)
train_array = sitk.GetArrayFromImage(train_image)
# Assumes a 3-D volume: GetArrayFromImage orders axes (z, y, x), so indexing
# axis 0 picks a slice; display the middle one.
plt.imshow(train_array[train_array.shape[0] // 2], cmap='gray')
plt.show()