コード例 #1
0
    def run_net(self):
        """Build the forked-DenseNet graph and run the training loop.

        In order, the method:

        1. starts the validation and the training patch pipelines (one
           ``image_class`` plus an extractor, a fill and a read thread each);
        2. creates 8 image placeholders and 14 label placeholders and wires
           them into ``_forked_densenet2().densenet``;
        3. registers image summaries, the ``MSE_perf_angio`` cost and an Adam
           optimizer that runs after the ``UPDATE_OPS`` collection;
        4. loops over ``self.total_epochs``: a validation pass is run whenever
           ``itr1`` is a multiple of ``self.display_validation_step``, followed
           by training batches.  A checkpoint is written whenever ``point`` is
           a multiple of 100 and at the end of every epoch.

        Returns:
            None.  Results are written as TensorBoard summaries under
            ``self.LOGDIR`` and as checkpoints under ``self.chckpnt_dir``.
        """
        _rd = _read_data(data=self.data,
                         img_name=self.img_name,
                         label_name=self.label_name,
                         dataset_path=self.data_path)

        self.alpha_coeff = 1
        # Paths of the images for train, validation and test; the test split
        # is not used by this method.
        train_data, validation_data, _ = _rd.read_data_path()

        # ---------------- validation patch pipeline ----------------
        bunch_of_images_no = 20
        sample_no = 60
        _image_class_vl = image_class(
            validation_data,
            bunch_of_images_no=bunch_of_images_no,
            is_training=0,
            patch_window=self.patch_window,
            sample_no_per_bunch=sample_no,
            label_patch_size=self.label_patchs_size,
            validation_total_sample=self.validation_samples)

        _patch_extractor_thread_vl = _extractor_thread(
            _image_class=_image_class_vl,
            patch_window=self.patch_window,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=0,
            vl_sample_no=self.validation_samples)
        _fill_thread_vl = fill_thread(
            data=validation_data,
            _image_class=_image_class_vl,
            sample_no=sample_no,
            total_sample_no=self.validation_samples,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=0,
            patch_extractor=_patch_extractor_thread_vl,
            fold=self.fold)

        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl = read_thread(
            _fill_thread_vl,
            mutex=settings.mutex,
            validation_sample_no=self.validation_samples,
            is_training=0)
        _read_thread_vl.start()

        # ---------------- training patch pipeline ----------------
        _image_class = image_class(train_data,
                                   bunch_of_images_no=bunch_of_images_no,
                                   is_training=1,
                                   patch_window=self.patch_window,
                                   sample_no_per_bunch=sample_no,
                                   label_patch_size=self.label_patchs_size,
                                   validation_total_sample=0)

        patch_extractor_thread = _extractor_thread(
            _image_class=_image_class,
            patch_window=self.patch_window,
            label_patchs_size=self.label_patchs_size,
            mutex=settings.mutex,
            is_training=1)
        _fill_thread = fill_thread(train_data,
                                   _image_class,
                                   sample_no=sample_no,
                                   total_sample_no=self.sample_no,
                                   label_patchs_size=self.label_patchs_size,
                                   mutex=settings.mutex,
                                   is_training=1,
                                   patch_extractor=patch_extractor_thread,
                                   fold=self.fold)

        _fill_thread.start()
        patch_extractor_thread.start()

        _read_thread = read_thread(_fill_thread,
                                   mutex=settings.mutex,
                                   is_training=1)
        _read_thread.start()

        # ---------------- graph ----------------
        img_rows = self._make_placeholders('img_row', 8)
        labels = self._make_placeholders('label', 14)
        is_training = tf.placeholder(tf.bool, name='is_training')
        input_dim = tf.placeholder(tf.int32, name='input_dim')
        placeholders = (img_rows, labels, is_training, input_dim)

        forked_densenet = _forked_densenet2()
        net_inputs = {
            'img_row%d' % idx: ph for idx, ph in enumerate(img_rows, 1)
        }
        (perf_y, perf_loss_fm1, perf_loss_fm2,
         angio_y, angio_loss_fm1, angio_loss_fm2) = forked_densenet.densenet(
             input_dim=input_dim, is_training=is_training, **net_inputs)

        self._add_slice_summaries(perf_y, angio_y, labels)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        log_suffix = ''
        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train' +
                                             log_suffix,
                                             graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation' +
                                                  log_suffix,
                                                  graph=sess.graph)
        # NOTE(review): the saver is built from the variables that exist at
        # this point, i.e. before the optimizer below is created, so variables
        # added by the optimizer are presumably not checkpointed -- confirm
        # that this is intended before relying on resumed training.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        utils.backup_code(self.LOGDIR)

        with tf.name_scope('MSE_perf_angio'):
            loss_args = {
                'label%d' % idx: lab for idx, lab in enumerate(labels, 1)
            }
            # logit1..7 are channels 0..6 of perf_y, logit8..14 are channels
            # 0..6 of angio_y, each kept as a trailing singleton channel.
            branch_channels = [(branch, channel)
                               for branch in (perf_y, angio_y)
                               for channel in range(7)]
            for idx, (branch, channel) in enumerate(branch_channels, 1):
                loss_args['logit%d' % idx] = branch[:, :, :, :, channel,
                                                    np.newaxis]
            MSE_perf_angio = self.loss_instance.MSE_perf_angio(
                perf_loss_fm1=perf_loss_fm1,
                perf_loss_fm2=perf_loss_fm2,
                angio_loss_fm1=angio_loss_fm1,
                angio_loss_fm2=angio_loss_fm2,
                **loss_args)
            cost = tf.reduce_mean(MSE_perf_angio, name="cost")
        tf.summary.scalar("cost", MSE_perf_angio)

        # Run the collected update ops before every optimizer step.
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            optimizer = tf.train.AdamOptimizer(
                self.learning_rate).minimize(cost)

        sess.run(tf.global_variables_initializer())
        logging.debug('total number of variables %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))

        summ = tf.summary.merge_all()
        loadModel = 0
        point = 0
        itr1 = 0
        if loadModel:
            chckpnt_dir = ''
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Plain int: a 16-bit integer cannot hold steps above 32767.
            point = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            itr1 = point

        # The process handle does not change, so it is created only once.
        process = psutil.Process(os.getpid())

        # ---------------- epochs ----------------
        for epoch in range(self.total_epochs):
            # Set once per epoch so the elapsed time printed below is always
            # defined and covers the whole epoch.
            startTime = time.time()
            while self.no_sample_per_each_itr * int(
                    point / self.no_sample_per_each_itr) < self.sample_no:
                print('0')
                print("epoch #: %d" % (epoch))

                step = 0
                # The complete validation pass belongs under this guard.
                # Guarding only the counter reset would, on all other
                # iterations, divide the stale averages again and write the
                # last training batch to the validation writer.
                if itr1 % self.display_validation_step == 0:
                    self._run_validation(sess, cost, summ, placeholders,
                                         _image_class_vl, _read_thread_vl,
                                         _fill_thread_vl, validation_writer,
                                         itr1, point, process)

                # Training batches.
                while (step * self.batch_no < self.no_sample_per_each_itr):
                    [crush, noncrush, perf,
                     angio] = _image_class.return_patches_tr(self.batch_no)

                    # Not enough patches yet: wake the reader and retry.
                    if (len(angio) < self.batch_no):
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    # NOTE(review): the optimizer step feeds
                    # is_training=False, exactly as before.  If layers inside
                    # the network depend on this flag, True is probably
                    # intended here -- confirm against the network code.
                    feed = self._make_feed_dict(placeholders, crush, noncrush,
                                                perf, angio)
                    loss_train1, _ = sess.run([cost, optimizer],
                                              feed_dict=feed)

                    self.x_hist = self.x_hist + 1

                    # Summaries are evaluated after the update, on the same
                    # batch.
                    [sum_train] = sess.run([summ], feed_dict=feed)
                    train_writer.add_summary(sum_train, point)
                    step = step + 1

                    print(
                        'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s'
                        % (int(point), step * self.batch_no,
                           self.learning_rate, loss_train1,
                           str(process.memory_percent())))

                    if point % 100 == 0:
                        # Intermediate checkpoint inside the epoch.
                        chckpnt_path = os.path.join(
                            self.chckpnt_dir,
                            ('unet_inter_epoch%d_point%d.ckpt' %
                             (epoch, point)))
                        saver.save(sess, chckpnt_path, global_step=point)

                    itr1 = itr1 + 1
                    point = point + 1

            endTime = time.time()
            # Checkpoint at the end of every epoch.
            chckpnt_path = os.path.join(self.chckpnt_dir, 'unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=epoch)

            print("End of epoch----> %d, elapsed time: %d" %
                  (epoch, endTime - startTime))

    @staticmethod
    def _make_placeholders(prefix, count):
        """Return ``count`` float32 placeholders named ``<prefix>1..<count>``.

        Every placeholder has shape ``[None, None, None, None, 1]``.
        """
        return [
            tf.placeholder(tf.float32,
                           shape=[None, None, None, None, 1],
                           name='%s%d' % (prefix, idx))
            for idx in range(1, count + 1)
        ]

    def _add_slice_summaries(self, perf_y, angio_y, labels):
        """Register image summaries of one slice of every output and label.

        The slice is taken at index ``int(self.label_patchs_size / 2)`` along
        axis 1.  ``out0..out6`` show channels 0..6 of ``perf_y``,
        ``out7..out13`` channels 0..6 of ``angio_y`` and
        ``groundtruth1..groundtruth14`` channel 0 of each label placeholder.
        """
        mid = int(self.label_patchs_size / 2)
        predicted = [
            branch[:, mid, :, :, channel, np.newaxis]
            for branch in (perf_y, angio_y) for channel in range(7)
        ]
        truth = [lab[:, mid, :, :, 0, np.newaxis] for lab in labels]

        for idx, image in enumerate(predicted):
            tf.summary.image('out%d' % idx, image, 3)
        for idx, image in enumerate(truth, 1):
            tf.summary.image('groundtruth%d' % idx, image, 3)

    def _make_feed_dict(self, placeholders, crush, noncrush, perf, angio,
                        training=False):
        """Map one batch of patches onto the graph placeholders.

        ``placeholders`` is the tuple ``(img_rows, labels, is_training,
        input_dim)``.  Image placeholder ``k`` (0-based) receives index ``k``
        along axis 1 of ``crush`` when ``k`` is even and of ``noncrush`` when
        ``k`` is odd.  The first seven labels receive indices 0..6 of ``perf``
        and the last seven indices 0..6 of ``angio``.
        """
        img_rows, labels, is_training, input_dim = placeholders
        feed = {}
        for idx, placeholder in enumerate(img_rows):
            source = crush if idx % 2 == 0 else noncrush
            feed[placeholder] = source[:, idx, :, :, :, :]
        for idx in range(7):
            feed[labels[idx]] = perf[:, idx, :, :, :, :]
            feed[labels[idx + 7]] = angio[:, idx, :, :, :, :]
        feed[is_training] = training
        feed[input_dim] = self.patch_window
        return feed

    def _run_validation(self, sess, cost, summ, placeholders, image_class_vl,
                        read_thread_vl, fill_thread_vl, writer, itr1, point,
                        process):
        """Run one validation pass and write its summary to ``writer``.

        Batches of ``self.batch_no_validation`` patches are evaluated until
        ``settings.validation_totalimg_patch`` patches have been covered.  The
        mean cost is printed and the merged summary of the last evaluated
        batch is written at step ``point``.
        """
        loss_validation = 0
        acc_validation = 0
        dsc_validation = 0
        validation_step = 0

        while (validation_step * self.batch_no_validation <
               settings.validation_totalimg_patch):
            [crush, noncrush, perf, angio] = image_class_vl.return_patches_vl(
                validation_step * self.batch_no_validation,
                (validation_step + 1) * self.batch_no_validation,
                is_tr=False)
            # Not enough patches yet: wake the reader and retry.
            if (len(angio) < self.batch_no_validation):
                read_thread_vl.resume()
                time.sleep(0.5)
                continue

            feed = self._make_feed_dict(placeholders, crush, noncrush, perf,
                                        angio)
            [loss_vali] = sess.run([cost], feed_dict=feed)
            loss_validation += loss_vali

            validation_step += 1
            if np.isnan(dsc_validation) or np.isnan(
                    loss_validation) or np.isnan(acc_validation):
                print('nan problem')
            print('%d - > %d:   loss_validation: %f, memory_percent: %4s' % (
                validation_step,
                validation_step * self.batch_no_validation,
                loss_vali,
                str(process.memory_percent()),
            ))

        settings.queue_isready_vl = False
        # Only the loss is accumulated above; the accuracy and dsc counters
        # are never updated and therefore stay 0.
        acc_validation = acc_validation / (validation_step)
        loss_validation = loss_validation / (validation_step)
        dsc_validation = dsc_validation / (validation_step)

        if np.isnan(dsc_validation) or np.isnan(
                loss_validation) or np.isnan(acc_validation):
            print('nan problem')
        fill_thread_vl.kill_thread()
        print('******Validation, step: %d , accuracy: %.4f, loss: %f*******' %
              (itr1, acc_validation, loss_validation))

        # The summary is computed on the last validation batch only.
        [sum_validation] = sess.run([summ], feed_dict=feed)
        writer.add_summary(sum_validation, point)
        print('end of validation---------%d' % (point))
コード例 #2
0
    def run_net(self,loadModel = 0):

        # pre_bn=tf.placeholder(tf.float32,shape=[None,None,None,None,None])
        # image=tf.placeholder(tf.float32,shape=[self.batch_no,self.patch_window,self.patch_window,self.patch_window,1])
        # label=tf.placeholder(tf.float32,shape=[self.batch_no_validation,self.label_patchs_size,self.label_patchs_size,self.label_patchs_size,2])
        # loss_coef=tf.placeholder(tf.float32,shape=[self.batch_no_validation,1,1,1])
        # ===================================================================================
        _rd = _read_data(data=self.data,
                         img_name=self.img_name, label_name=self.label_name,
                         dataset_path=self.data_path,reverse=self.newdataset)

        self.alpha_coeff = 1
        '''read path of the images for train, test, and validation'''
        train_data, validation_data, test_data = _rd.read_data_path()

        # ======================================
        bunch_of_images_no = 20
        sample_no = 60
        _image_class_vl = image_class(validation_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=0,
                                      patch_window=self.patch_window,
                                      sample_no_per_bunch=sample_no,
                                      label_patch_size=self.label_patchs_size,
                                      validation_total_sample=self.validation_samples)

        _patch_extractor_thread_vl = _extractor_thread(_image_class=_image_class_vl,
                                                       patch_window=self.patch_window,
                                                       label_patchs_size=self.label_patchs_size,

                                                       mutex=settings.mutex,
                                                       is_training=0,
                                                       vl_sample_no=self.validation_samples
                                                       )
        _fill_thread_vl = fill_thread(data=validation_data,
                                      _image_class=_image_class_vl,
                                      sample_no=sample_no,
                                      total_sample_no=self.validation_samples,
                                      label_patchs_size=self.label_patchs_size,
                                      mutex=settings.mutex,
                                      is_training=0,
                                      patch_extractor=_patch_extractor_thread_vl,
                                      fold=self.fold)

        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl = read_thread(_fill_thread_vl,
                                      mutex=settings.mutex,
                                      validation_sample_no=self.validation_samples,
                                      is_training=0)
        _read_thread_vl.start()
        # ======================================
        bunch_of_images_no = 20
        sample_no = 60
        _image_class = image_class(train_data
                                   , bunch_of_images_no=bunch_of_images_no,
                                   is_training=1,
                                   patch_window=self.patch_window,
                                   sample_no_per_bunch=sample_no,
                                   label_patch_size=self.label_patchs_size,
                                   validation_total_sample=0)

        patch_extractor_thread = _extractor_thread(_image_class=_image_class,
                                                   patch_window=self.patch_window,
                                                   label_patchs_size=self.label_patchs_size,
                                                   mutex=settings.mutex, is_training=1)
        _fill_thread = fill_thread(train_data,
                                   _image_class,
                                   sample_no=sample_no, total_sample_no=self.sample_no,
                                   label_patchs_size=self.label_patchs_size,
                                   mutex=settings.mutex,
                                   is_training=1,
                                   patch_extractor=patch_extractor_thread,
                                   fold=self.fold)
        #
        _fill_thread.start()
        patch_extractor_thread.start()

        _read_thread = read_thread(_fill_thread, mutex=settings.mutex, is_training=1)
        _read_thread.start()
        # ===================================================================================
        img_row1 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row1')
        img_row2 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row2')
        img_row3 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row3')
        img_row4 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row4')
        img_row5 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row5')
        img_row6 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row6')
        img_row7 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row7')
        img_row8 = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='img_row8')

        mri_ph = tf.placeholder(tf.float32,
                                  shape=[self.batch_no, self.patch_window, self.patch_window, self.patch_window, 1],name='mri')

        label1 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label1')
        label2 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label2')
        label3 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label3')
        label4 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label4')
        label5 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label5')
        label6 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label6')
        label7 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label7')
        label8 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label8')
        label9 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                   self.label_patchs_size, 1], name='label9')
        label10 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label10')
        label11 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label11')
        label12 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label12')
        label13 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label13')
        label14 = tf.placeholder(tf.float32, shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
                                                    self.label_patchs_size, 1], name='label14')

        # loss_placeholder = tf.placeholder(tf.float32,
        #                                   shape=[self.batch_no, self.label_patchs_size, self.label_patchs_size,
        #                                          self.label_patchs_size, 1])

        # img_row1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row1')
        # img_row2 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row2')
        # img_row3 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row3')
        # img_row4 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row4')
        # img_row5 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row5')
        # img_row6 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row6')
        # img_row7 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row7')
        # img_row8 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='img_row8')
        #
        # label1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label1')
        # label2 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label2')
        # label3 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label3')
        # label4 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label4')
        # label5 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label5')
        # label6 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label6')
        # label7 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label7')
        # label8 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label8')
        # label9 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label9')
        # label10 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label10')
        # label11 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label11')
        # label12 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label12')
        # label13 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label13')
        # label14 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1],name='label14')

        is_training = tf.placeholder(tf.bool, name='is_training')
        input_dim = tf.placeholder(tf.int32, name='input_dim')

        # unet=_unet()
        multi_stage_densenet = _multi_stage_densenet()

        y,loss_upsampling11,loss_upsampling22 = multi_stage_densenet.multi_stage_densenet(img_row1=img_row1,
                              img_row2=img_row2,
                              img_row3=img_row3,
                              img_row4=img_row4,
                              img_row5=img_row5,
                              img_row6=img_row6,
                              img_row7=img_row7,
                              img_row8=img_row8,
                              input_dim=input_dim,
                              mri = mri_ph,
                              is_training=is_training)
        # self.vgg = vgg_feature_maker()
        # self.vgg_y0 = self.vgg.feed_img(loss_placeholder)
        # self.vgg_y0 = self.vgg.feed_img(y[:, :, :, :, 0]).copy()
        # self.vgg_y1 = self.vgg.feed_img(y[:, :, :, :, 1]).copy()
        # self.vgg_y2 = self.vgg.feed_img(y[:, :, :, :, 2]).copy()
        # self.vgg_y3 = self.vgg.feed_img(y[:, :, :, :, 3]).copy()
        # self.vgg_y4 = self.vgg.feed_img(y[:, :, :, :, 4]).copy()
        # self.vgg_y5 = self.vgg.feed_img(y[:, :, :, :, 5]).copy()
        # self.vgg_y6 = self.vgg.feed_img(y[:, :, :, :, 6]).copy()
        # self.vgg_y7 = self.vgg.feed_img(y[:, :, :, :, 7])
        # self.vgg_y8 = self.vgg.feed_img(y[:, :, :, :, 8])
        # self.vgg_y9 = self.vgg.feed_img(y[:, :, :, :, 9])
        # self.vgg_y10 = self.vgg.feed_img(y[:, :, :, :, 10])
        # self.vgg_y11 = self.vgg.feed_img(y[:, :, :, :, 11])
        # self.vgg_y12 = self.vgg.feed_img(y[:, :, :, :, 12])
        # self.vgg_y13 = self.vgg.feed_img(y[:, :, :, :, 13])

        # self.vgg_label0 = self.vgg.feed_img(label1[:,:,:,:,0]).copy()
        # self.vgg_label1 = self.vgg.feed_img(label2[:,:,:,:,0]).copy()
        # self.vgg_label2 = self.vgg.feed_img(label3[:,:,:,:,0]).copy()
        # self.vgg_label3 = self.vgg.feed_img(label4[:,:,:,:,0]).copy()
        # self.vgg_label4 = self.vgg.feed_img(label5[:,:,:,:,0]).copy()
        # self.vgg_label5 = self.vgg.feed_img(label6[:,:,:,:,0]).copy()
        # self.vgg_label6 = self.vgg.feed_img(label7[:,:,:,:,0]).copy()
        # self.vgg_label7 = self.vgg.feed_img(label8[:,:,:,:,0])
        # self.vgg_label8 = self.vgg.feed_img(label9[:,:,:,:,0])
        # self.vgg_label9 = self.vgg.feed_img(label10[:,:,:,:,0])
        # self.vgg_label10 = self.vgg.feed_img(label11[:,:,:,:,0])
        # self.vgg_label11 = self.vgg.feed_img(label12[:,:,:,:,0])
        # self.vgg_label12 = self.vgg.feed_img(label13[:,:,:,:,0])
        # self.vgg_label13 = self.vgg.feed_img(label14[:,:,:,:,0])

        y_dirX = ((y[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis]))
        y_dirX1 = ((y[:, int(self.label_patchs_size / 2), :, :, 1, np.newaxis]))
        y_dirX2 = ((y[:, int(self.label_patchs_size / 2), :, :, 2, np.newaxis]))
        y_dirX3 = ((y[:, int(self.label_patchs_size / 2), :, :, 3, np.newaxis]))
        y_dirX4 = ((y[:, int(self.label_patchs_size / 2), :, :, 4, np.newaxis]))
        y_dirX5 = ((y[:, int(self.label_patchs_size / 2), :, :, 5, np.newaxis]))
        y_dirX6 = ((y[:, int(self.label_patchs_size / 2), :, :, 6, np.newaxis]))
        y_dirX7 = ((y[:, int(self.label_patchs_size / 2), :, :, 7, np.newaxis]))
        y_dirX8 = ((y[:, int(self.label_patchs_size / 2), :, :, 8, np.newaxis]))
        y_dirX9 = ((y[:, int(self.label_patchs_size / 2), :, :, 9, np.newaxis]))
        y_dirX10 = ((y[:, int(self.label_patchs_size / 2), :, :, 10, np.newaxis]))
        y_dirX11 = ((y[:, int(self.label_patchs_size / 2), :, :, 11, np.newaxis]))
        y_dirX12 = ((y[:, int(self.label_patchs_size / 2), :, :, 12, np.newaxis]))
        y_dirX13 = ((y[:, int(self.label_patchs_size / 2), :, :, 13, np.newaxis]))

        label_dirX1 = (label1[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX2 = (label2[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX3 = (label3[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX4 = (label4[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX5 = (label5[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX6 = (label6[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX7 = (label7[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX8 = (label8[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX9 = (label9[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX10 = (label10[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX11 = (label11[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX12 = (label12[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX13 = (label13[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])
        label_dirX14 = (label14[:, int(self.label_patchs_size / 2), :, :, 0, np.newaxis])

        tf.summary.image('out0', y_dirX, 3)
        tf.summary.image('out1', y_dirX1, 3)
        tf.summary.image('out2', y_dirX2, 3)
        tf.summary.image('out3', y_dirX3, 3)
        tf.summary.image('out4', y_dirX4, 3)
        tf.summary.image('out5', y_dirX5, 3)
        tf.summary.image('out6', y_dirX6, 3)
        tf.summary.image('out7', y_dirX7, 3)
        tf.summary.image('out8', y_dirX8, 3)
        tf.summary.image('out9', y_dirX9, 3)
        tf.summary.image('out10', y_dirX10, 3)
        tf.summary.image('out11', y_dirX11, 3)
        tf.summary.image('out12', y_dirX12, 3)
        tf.summary.image('out13', y_dirX13, 3)

        tf.summary.image('groundtruth1', label_dirX1, 3)
        tf.summary.image('groundtruth2', label_dirX2, 3)
        tf.summary.image('groundtruth3', label_dirX3, 3)
        tf.summary.image('groundtruth4', label_dirX4, 3)
        tf.summary.image('groundtruth5', label_dirX5, 3)
        tf.summary.image('groundtruth6', label_dirX6, 3)
        tf.summary.image('groundtruth7', label_dirX7, 3)
        tf.summary.image('groundtruth8', label_dirX8, 3)
        tf.summary.image('groundtruth9', label_dirX9, 3)
        tf.summary.image('groundtruth10', label_dirX10, 3)
        tf.summary.image('groundtruth11', label_dirX11, 3)
        tf.summary.image('groundtruth12', label_dirX12, 3)
        tf.summary.image('groundtruth13', label_dirX13, 3)
        tf.summary.image('groundtruth14', label_dirX14, 3)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train' , graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation' , graph=sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if loadModel==0:
            utils.backup_code(self.LOGDIR)
        labels = []
        labels.append(label1)
        labels.append(label2)
        labels.append(label3)
        labels.append(label4)
        labels.append(label5)
        labels.append(label6)
        labels.append(label7)
        labels.append(label8)
        labels.append(label9)
        labels.append(label10)
        labels.append(label11)
        labels.append(label12)
        labels.append(label13)
        labels.append(label14)

        logits = []
        logits.append(y[:, :, :, :, 0, np.newaxis])
        logits.append(y[:, :, :, :, 1, np.newaxis])
        logits.append(y[:, :, :, :, 2, np.newaxis])
        logits.append(y[:, :, :, :, 3, np.newaxis])
        logits.append(y[:, :, :, :, 4, np.newaxis])
        logits.append(y[:, :, :, :, 5, np.newaxis])
        logits.append(y[:, :, :, :, 6, np.newaxis])

        logits.append(y[:, :, :, :, 7, np.newaxis])
        logits.append(y[:, :, :, :, 8, np.newaxis])
        logits.append(y[:, :, :, :, 9, np.newaxis])
        logits.append(y[:, :, :, :, 10, np.newaxis])
        logits.append(y[:, :, :, :, 11, np.newaxis])
        logits.append(y[:, :, :, :, 12, np.newaxis])
        logits.append(y[:, :, :, :, 13, np.newaxis])

        stage1 = []
        stage1.append(loss_upsampling11[:, :, :, :, 0, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 1, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 2, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 3, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 4, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 5, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 6, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 7, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 8, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 9, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 10, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 11, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 12, np.newaxis])
        stage1.append(loss_upsampling11[:, :, :, :, 13, np.newaxis])

        stage2 = []
        stage2.append(loss_upsampling22[:, :, :, :, 0, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 1, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 2, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 3, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 4, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 5, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 6, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 7, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 8, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 9, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 10, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 11, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 12, np.newaxis])
        stage2.append(loss_upsampling22[:, :, :, :, 13, np.newaxis])



        with tf.name_scope('Loss'):
            loss_dic = self.loss_instance.loss_selector('Multistage_ssim_perf_angio_loss',
                                                        labels=labels,logits=logits,
                                                        stage1=stage1,
                                                        stage2=stage2)
            cost = tf.reduce_mean(loss_dic["loss"], name="cost")

        with tf.variable_scope("summary"):
            tf.summary.scalar("Loss/loss", cost)

        # ============================================
        all_loss = tf.placeholder(tf.float32, name='loss')
        with tf.name_scope('validation'):
            ave_loss = all_loss
        tf.summary.scalar("ave_loss", ave_loss)
        '''AdamOptimizer:'''
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(cost)

        sess.run(tf.global_variables_initializer())
        logging.debug('total number of variables %s' % (
            np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))

        summ = tf.summary.merge_all()

        point = 0
        if loadModel:
            chckpnt_dir = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/sythesize_code/ASL_LOG/MRI_in/experiment-2/unet_checkpoints/'
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            point = (ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            point = int(int(point) / self.display_train_step) * self.display_train_step

        '''loop for epochs'''
        for epoch in range(self.total_epochs):
            while self.no_sample_per_each_itr * int(point / self.no_sample_per_each_itr) < self.sample_no:
                print('0')
                print("epoch #: %d" % (epoch))
                startTime = time.time()

                step = 0
                # =============validation================
                if point % self.display_validation_step == 0:
                    '''Validation: '''
                    loss_validation = 0
                    loss_validation_p_ssim=0
                    loss_validation_p_perceptual=0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation = 0


                while (validation_step * self.batch_no_validation < settings.validation_totalimg_patch):

                    [crush, noncrush, perf, angio,mri,segmentation] = _image_class_vl.return_patches_vl(
                        validation_step * self.batch_no_validation,
                        (validation_step + 1) * self.batch_no_validation, is_tr=False
                        )
                    if (len(segmentation) < self.batch_no_validation):
                        _read_thread_vl.resume()
                        time.sleep(0.5)
                        continue
                    # continue

                    [loss_vali] = sess.run([cost],
                                  feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                             img_row2: noncrush[:, 1, :, :, :, :],
                                             img_row3: crush[:, 2, :, :, :, :],
                                             img_row4: noncrush[:, 3, :, :, :, :],
                                             img_row5: crush[:, 4, :, :, :, :],
                                             img_row6: noncrush[:, 5, :, :, :, :],
                                             img_row7: crush[:, 6, :, :, :, :],
                                             img_row8: noncrush[:, 7, :, :, :, :],
                                             mri_ph: mri[:, 0, :, :, :, :],
                                             label1: perf[:, 0, :, :, :, :],
                                             label2: perf[:, 1, :, :, :, :],
                                             label3: perf[:, 2, :, :, :, :],
                                             label4: perf[:, 3, :, :, :, :],
                                             label5: perf[:, 4, :, :, :, :],
                                             label6: perf[:, 5, :, :, :, :],
                                             label7: perf[:, 6, :, :, :, :],
                                             label8: angio[:, 0, :, :, :, :],
                                             label9: angio[:, 1, :, :, :, :],
                                             label10: angio[:, 2, :, :, :, :],
                                             label11: angio[:, 3, :, :, :, :],
                                             label12: angio[:, 4, :, :, :, :],
                                             label13: angio[:, 5, :, :, :, :],
                                             label14: angio[:, 6, :, :, :, :],
                                             is_training: False,
                                             input_dim: self.patch_window,
                                             all_loss: -1.,
                                             })
                    loss_validation += loss_vali

                    validation_step += 1
                    if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    process = psutil.Process(os.getpid())
                    print(
                        '%d - > %d:   loss_validation: %f, memory_percent: %4s' % (
                            validation_step, validation_step * self.batch_no_validation
                            , loss_vali, str(process.memory_percent()),
                        ))

                settings.queue_isready_vl = False
                acc_validation = acc_validation / (validation_step)
                loss_validation = loss_validation / (validation_step)
                loss_validation_p_ssim = loss_validation_p_ssim / (validation_step)
                loss_validation_p_perceptual = loss_validation_p_perceptual / (validation_step)
                dsc_validation = dsc_validation / (validation_step)

                if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                    print('nan problem')
                _fill_thread_vl.kill_thread()
                print('******Validation, step: %d , accuracy: %.4f, loss: %f*******' % (
                    point, acc_validation, loss_validation))
                [sum_validation] = sess.run([summ],
                                            feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                       img_row2: noncrush[:, 1, :, :, :, :],
                                                       img_row3: crush[:, 2, :, :, :, :],
                                                       img_row4: noncrush[:, 3, :, :, :, :],
                                                       img_row5: crush[:, 4, :, :, :, :],
                                                       img_row6: noncrush[:, 5, :, :, :, :],
                                                       img_row7: crush[:, 6, :, :, :, :],
                                                       img_row8: noncrush[:, 7, :, :, :, :],
                                                       mri_ph: mri[:, 0, :, :, :, :],
                                                       label1: perf[:, 0, :, :, :, :],
                                                       label2: perf[:, 1, :, :, :, :],
                                                       label3: perf[:, 2, :, :, :, :],
                                                       label4: perf[:, 3, :, :, :, :],
                                                       label5: perf[:, 4, :, :, :, :],
                                                       label6: perf[:, 5, :, :, :, :],
                                                       label7: perf[:, 6, :, :, :, :],
                                                       label8: angio[:, 0, :, :, :, :],
                                                       label9: angio[:, 1, :, :, :, :],
                                                       label10: angio[:, 2, :, :, :, :],
                                                       label11: angio[:, 3, :, :, :, :],
                                                       label12: angio[:, 4, :, :, :, :],
                                                       label13: angio[:, 5, :, :, :, :],
                                                       label14: angio[:, 6, :, :, :, :],
                                                       is_training: False,
                                                       input_dim: self.patch_window,
                                                       all_loss: loss_validation,
                                                       })
                validation_writer.add_summary(sum_validation, point)
                print('end of validation---------%d' % (point))

                '''loop for training batches'''
                while (step * self.batch_no < self.no_sample_per_each_itr):

                    # [train_CT_image_patchs, train_GTV_label, train_Penalize_patch,loss_coef_weights] = _image_class.return_patches( self.batch_no)

                    [crush, noncrush, perf, angio,mri,segmentation] = _image_class.return_patches_tr(self.batch_no)

                    if (len(segmentation) < self.batch_no):
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    if point % self.display_train_step == 0:
                        '''train: '''
                        train_step=0

                    while (train_step <self.display_train_step):
                        if train_step==0:
                            average_loss_train1=0
                        [loss_train1, optimizing, out ] = sess.run([cost, optimizer, y,],
                                      feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                 img_row2: noncrush[:, 1, :, :, :, :],
                                                 img_row3: crush[:, 2, :, :, :, :],
                                                 img_row4: noncrush[:, 3, :, :, :, :],
                                                 img_row5: crush[:, 4, :, :, :, :],
                                                 img_row6: noncrush[:, 5, :, :, :, :],
                                                 img_row7: crush[:, 6, :, :, :, :],
                                                 img_row8: noncrush[:, 7, :, :, :, :],
                                                 mri_ph: mri[:, 0, :, :, :, :],
                                                 label1: perf[:, 0, :, :, :, :],
                                                 label2: perf[:, 1, :, :, :, :],
                                                 label3: perf[:, 2, :, :, :, :],
                                                 label4: perf[:, 3, :, :, :, :],
                                                 label5: perf[:, 4, :, :, :, :],
                                                 label6: perf[:, 5, :, :, :, :],
                                                 label7: perf[:, 6, :, :, :, :],
                                                 label8: angio[:, 0, :, :, :, :],
                                                 label9: angio[:, 1, :, :, :, :],
                                                 label10: angio[:, 2, :, :, :, :],
                                                 label11: angio[:, 3, :, :, :, :],
                                                 label12: angio[:, 4, :, :, :, :],
                                                 label13: angio[:, 5, :, :, :, :],
                                                 label14: angio[:, 6, :, :, :, :],
                                                 is_training: False,
                                                 input_dim: self.patch_window,
                                                 all_loss: -1,
                                                 })
                        average_loss_train1+=loss_train1

                        train_step=train_step+1
                        print(
                            'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s, ' % (
                                int((point+train_step)),
                                step * self.batch_no, self.learning_rate, loss_train1,
                                str(process.memory_percent())))

                    average_loss_train1 /= self.display_train_step


                    [sum_train] = sess.run([summ],
                                           feed_dict={img_row1: crush[:, 0, :, :, :, :],
                                                      img_row2: noncrush[:, 1, :, :, :, :],
                                                      img_row3: crush[:, 2, :, :, :, :],
                                                      img_row4: noncrush[:, 3, :, :, :, :],
                                                      img_row5: crush[:, 4, :, :, :, :],
                                                      img_row6: noncrush[:, 5, :, :, :, :],
                                                      img_row7: crush[:, 6, :, :, :, :],
                                                      img_row8: noncrush[:, 7, :, :, :, :],
                                                      mri_ph: mri[:, 0, :, :, :, :],
                                                      label1: perf[:, 0, :, :, :, :],
                                                      label2: perf[:, 1, :, :, :, :],
                                                      label3: perf[:, 2, :, :, :, :],
                                                      label4: perf[:, 3, :, :, :, :],
                                                      label5: perf[:, 4, :, :, :, :],
                                                      label6: perf[:, 5, :, :, :, :],
                                                      label7: perf[:, 6, :, :, :, :],
                                                      label8: angio[:, 0, :, :, :, :],
                                                      label9: angio[:, 1, :, :, :, :],
                                                      label10: angio[:, 2, :, :, :, :],
                                                      label11: angio[:, 3, :, :, :, :],
                                                      label12: angio[:, 4, :, :, :, :],
                                                      label13: angio[:, 5, :, :, :, :],
                                                      label14: angio[:, 6, :, :, :, :],
                                                      is_training: False,
                                                      input_dim: self.patch_window,
                                                      all_loss: loss_train1,


                                                      })
                    train_writer.add_summary(sum_train, point)
                    step = step + 1

                    point = point + self.display_train_step
                    if point % 200 == 0:
                        break
                    process = psutil.Process(os.getpid())
                    print(
                        'point: %d, step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s' % (
                            int((point)),
                            step * self.batch_no, self.learning_rate, loss_train1,
                            str(process.memory_percent())))
                    point = int((point))  # (self.no_sample_per_each_itr/self.batch_no)*itr1+step

                    if point % 100 == 0:
                        '''saveing model inter epoch'''
                        chckpnt_path = os.path.join(self.chckpnt_dir,
                                                    ('unet_inter_epoch%d_point%d.ckpt' % (epoch, point)))
                        saver.save(sess, chckpnt_path, global_step=point)



            endTime = time.time()
            # ==============end of epoch:
            '''saveing model after each epoch'''
            chckpnt_path = os.path.join(self.chckpnt_dir, 'unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=epoch)

            print("End of epoch----> %d, elapsed time: %d" % (epoch, endTime - startTime))