def run_net(self, no_averages):
        """Build the multi-stage DenseNet graph and train it on ASL/T1/PET patches.

        Starts the background patch-reading threads for the validation and
        training sets, builds the TF1 graph (placeholders, network, SSIM+Huber
        costs, two Adam optimizers), then interleaves validation passes and
        training steps, writing TensorBoard summaries and checkpoints.

        Training alternates between two objectives in a 10-step cycle: five
        'AMUC' steps (``hybrid_training_f`` True) minimise ``cost_withpet``
        (ASL + PET loss), then five 'LUMC' steps minimise ``cost_withoutpet``
        (ASL loss only). Every accepted training batch gets two optimizer
        steps: the first with an all-ones residual attention map, the second
        with the attention map returned by the first, zero-padded to the ASL
        patch size.

        Args:
            no_averages: forwarded as ``average_no`` to
                ``_read_data.read_data_path``.

        NOTE(review): every placeholder has a fixed batch dimension of
        ``self.batch_no``, so validation batches presumably require
        ``self.batch_no_validation == self.batch_no`` -- confirm.
        """
        self.alpha_coeff = 1
        # Read the image paths for the train, validation and test sets.
        _rd = _read_data(self.data_path_AMUC, self.data_path_LUMC)
        train_data, validation_data, test_data = _rd.read_data_path(
            average_no=no_averages)

        # ============== validation data pipeline ==============
        bunch_of_images_no = 1
        _image_class_vl = image_class(validation_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=0,
                                      inp_size=self.asl_size,
                                      out_size=self.pet_size)
        _patch_extractor_thread_vl = _patch_extractor_thread._patch_extractor_thread(
            _image_class=_image_class_vl,
            img_no=bunch_of_images_no,
            mutex=settings.mutex,
            is_training=0)
        _fill_thread_vl = fill_thread.fill_thread(
            validation_data,
            _image_class_vl,
            mutex=settings.mutex,
            is_training=0,
            patch_extractor=_patch_extractor_thread_vl,
        )
        _read_thread_vl = read_thread.read_thread(
            _fill_thread_vl,
            mutex=settings.mutex,
            validation_sample_no=self.validation_samples,
            is_training=0)
        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl.start()

        # ============== training data pipeline ==============
        bunch_of_images_no = 14
        _image_class_tr = image_class(train_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=1,
                                      inp_size=self.asl_size,
                                      out_size=self.pet_size)
        _patch_extractor_thread_tr = _patch_extractor_thread._patch_extractor_thread(
            _image_class=_image_class_tr,
            img_no=bunch_of_images_no,
            mutex=settings.mutex,
            is_training=1,
        )
        _fill_thread = fill_thread.fill_thread(
            train_data,
            _image_class_tr,
            mutex=settings.mutex,
            is_training=1,
            patch_extractor=_patch_extractor_thread_tr,
        )
        _read_thread = read_thread.read_thread(_fill_thread,
                                               mutex=settings.mutex,
                                               is_training=1)
        _fill_thread.start()
        _patch_extractor_thread_tr.start()
        _read_thread.start()

        # ============== graph inputs ==============
        asl_plchld = tf.placeholder(
            tf.float32, shape=[self.batch_no, self.asl_size, self.asl_size, 1])
        t1_plchld = tf.placeholder(
            tf.float32, shape=[self.batch_no, self.asl_size, self.asl_size, 1])
        pet_plchld = tf.placeholder(
            tf.float32, shape=[self.batch_no, self.pet_size, self.pet_size, 1])
        asl_out_plchld = tf.placeholder(
            tf.float32, shape=[self.batch_no, self.pet_size, self.pet_size, 1])
        residual_attention_map = tf.placeholder(
            tf.float32, shape=[self.batch_no, self.asl_size, self.asl_size, 1])

        ave_loss_vali = tf.placeholder(tf.float32, name='ave_loss_vali')
        # True: train in a multi-task fashion, False: train in a single-task fashion
        hybrid_training_flag = tf.placeholder(tf.bool,
                                              name='hybrid_training_flag')

        is_training = tf.placeholder(tf.bool, name='is_training')
        is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

        msdensnet = multi_stage_densenet()
        asl_y, pet_y, att_map = msdensnet.multi_stage_densenet(
            asl_img=asl_plchld,
            t1_img=t1_plchld,
            pet_img=pet_plchld,
            hybrid_training_flag=hybrid_training_flag,
            input_dim=self.asl_size,
            is_training=is_training,
            config=self.config,
            residual_attention_map=residual_attention_map)

        # Image summaries: first channel of each input / target / output.
        for tag, tensor in (('00: input_asl', asl_plchld),
                            ('01: input_t1', t1_plchld),
                            ('02: target_pet', pet_plchld),
                            ('03: target_asl', asl_out_plchld),
                            ('04: output_asl', asl_y),
                            ('05: output_pet', pet_y)):
            tf.summary.image(tag, tensor[:, :, :, 0, np.newaxis], 3)

        print('*****************************************')
        print('*****************************************')
        print('*****************************************')
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        print('*****************************************')
        print('*****************************************')
        print('*****************************************')

        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train',
                                             graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation',
                                                  graph=sess.graph)
        # Best-effort snapshot of the training code next to the logs. A
        # failure (e.g. os.mkdir raising FileExistsError on a restart) must
        # not stop training, but it is reported rather than hidden.
        try:
            os.mkdir(self.LOGDIR + 'code/')
            copyfile('./run_hybrid_multitask_net_hr.py',
                     self.LOGDIR + 'code/run_hybrid_multitask_net_hr.py')
            copyfile('./submit_job_hybrid_multitask_hr.py',
                     self.LOGDIR + 'code/submit_job_hybrid_multitask_hr.py')
            shutil.copytree('./functions/', self.LOGDIR + 'code/functions/')
        except OSError as err:
            logging.warning('could not copy source code to %s: %s',
                            self.LOGDIR + 'code/', err)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        loadModel = 0
        # Weight of the SSIM term versus the Huber term in each loss.
        alpha = .84

        with tf.name_scope('cost'):
            ssim_asl = tf.reduce_mean(
                1 - SSIM(x1=asl_out_plchld, x2=asl_y, max_val=34.0)[0])
            loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(
                huber(labels=asl_out_plchld, logit=asl_y))

            ssim_pet = tf.reduce_mean(
                1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
            loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(
                huber(labels=pet_plchld, logit=pet_y))

            # Objective for batches with a PET target (AMUC) ...
            cost_withpet = tf.reduce_mean(loss_asl + loss_pet)
            # ... and for batches supervised on ASL only (LUMC).
            cost_withoutpet = loss_asl

        tf.summary.scalar("cost_withoutpet", cost_withoutpet)
        tf.summary.scalar("cost_withpet", cost_withpet)
        # Make every optimizer step also run the collected UPDATE_OPS
        # (e.g. batch-norm moving statistics, if the network registers any).
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(extra_update_ops):
            optimizer_withpet = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate, ).minimize(cost_withpet)
            optimizer_withoutpet = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate, ).minimize(cost_withoutpet)

        with tf.name_scope('validation'):
            average_validation_loss = ave_loss_vali

        tf.summary.scalar("average_validation_loss", average_validation_loss)
        sess.run(tf.global_variables_initializer())

        logging.debug('total number of variables %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))

        summ = tf.summary.merge_all()

        # Bounds of the central PET-sized window cut out of an ASL-sized
        # patch; the crop is fed as the ASL regression target.
        crop_lo = int(self.asl_size / 2) - int(self.pet_size / 2) - 1
        crop_hi = int(self.asl_size / 2) + int(self.pet_size / 2)
        # Bounds at which a PET-sized attention map is pasted into a
        # zero-filled ASL-sized map.
        pad_lo = int((self.asl_size - self.pet_size) / 2 - 1)
        pad_hi = int((self.asl_size - self.pet_size) / 2 + self.pet_size - 1)
        # Neutral residual attention map used whenever no feedback is applied.
        ones_attention = np.ones(
            [self.batch_no, self.asl_size, self.asl_size, 1])

        def _center_crop(slices):
            """Return the central PET-sized window of a batch of ASL patches."""
            return slices[:, crop_lo:crop_hi, crop_lo:crop_hi, :]

        def _pad_attention(attention):
            """Zero-pad a PET-sized attention map up to the ASL patch size."""
            padded = np.zeros([self.batch_no, self.asl_size, self.asl_size, 1])
            padded[:, pad_lo:pad_hi, pad_lo:pad_hi, :] = attention
            return padded

        def _feed(asl_slices, t1_slices, pet_slices, training, loss_value,
                  hybrid_flag, attention):
            """Build the feed_dict shared by every sess.run call below."""
            return {
                asl_plchld: asl_slices,
                t1_plchld: t1_slices,
                pet_plchld: pet_slices,
                asl_out_plchld: _center_crop(asl_slices),
                is_training: training,
                ave_loss_vali: loss_value,
                is_training_bn: training,
                hybrid_training_flag: hybrid_flag,
                residual_attention_map: attention,
            }

        point = 0
        itr1 = 0
        if loadModel:
            chckpnt_dir = ''
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Plain int: the global step can exceed np.int16's 32767 limit.
            point = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            itr1 = point

        print("Number of trainable parameters: %d" %
              self.count_number_trainable_params())

        process = psutil.Process(os.getpid())
        # Defined up front so the end-of-epoch report cannot hit a NameError
        # when the sample loop below never runs.
        startTime = time.time()

        # ============== epochs ==============
        for itr in range(self.total_iteration):
            # NOTE(review): once `point` reaches `sample_no`, this loop no
            # longer runs and later `itr` iterations only re-save the
            # end-of-epoch checkpoint -- confirm this is intended.
            while self.no_sample_per_each_itr * int(
                    point / self.no_sample_per_each_itr) < self.sample_no:
                print("epoch #: %d" % (settings.epochs_no))
                startTime = time.time()
                step = 0
                self.beta_coeff = 1 + 1 * np.exp(-point / 2000)

                # ============= validation =============
                if itr1 % self.display_validation_step == 0:
                    loss_validation = 0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation = 0

                    while (validation_step * self.batch_no_validation <
                           settings.validation_totalimg_patch):
                        [
                            validation_asl_slices, validation_pet_slices,
                            validation_t1_slices
                        ] = _image_class_vl.return_patches_validation(
                            validation_step * self.batch_no_validation,
                            (validation_step + 1) * self.batch_no_validation)
                        if (len(validation_asl_slices) < self.batch_no_validation
                                or len(validation_pet_slices) < self.batch_no_validation
                                or len(validation_t1_slices) < self.batch_no_validation):
                            # Not enough patches queued yet: wake the reader and retry.
                            _read_thread_vl.resume()
                            time.sleep(0.5)
                            continue
                        tic = time.time()

                        [loss_vali] = sess.run(
                            [cost_withpet],
                            feed_dict=_feed(validation_asl_slices,
                                            validation_t1_slices,
                                            validation_pet_slices,
                                            training=False,
                                            loss_value=-1,
                                            hybrid_flag=False,
                                            attention=ones_attention))
                        elapsed = time.time() - tic
                        loss_validation += loss_vali
                        validation_step += 1
                        if np.isnan(dsc_validation) or np.isnan(
                                loss_validation) or np.isnan(acc_validation):
                            print('nan problem')

                        print(
                            '%d - > %d: elapsed_time:%d  loss_validation: %f, memory_percent: %4s'
                            % (
                                validation_step,
                                validation_step * self.batch_no_validation,
                                elapsed,
                                loss_vali,
                                str(process.memory_percent()),
                            ))

                    settings.queue_isready_vl = False
                    acc_validation = acc_validation / (validation_step)
                    loss_validation = loss_validation / (validation_step)
                    dsc_validation = dsc_validation / (validation_step)
                    if np.isnan(dsc_validation) or np.isnan(
                            loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    _fill_thread_vl.kill_thread()
                    print(
                        '******Validation, step: %d , accuracy: %.4f, loss: %f*******'
                        % (itr1, acc_validation, loss_validation))

                    # Summaries on the last validation batch, with the mean
                    # validation loss fed into `ave_loss_vali`.
                    [sum_validation] = sess.run(
                        [summ],
                        feed_dict=_feed(validation_asl_slices,
                                        validation_t1_slices,
                                        validation_pet_slices,
                                        training=False,
                                        loss_value=loss_validation,
                                        hybrid_flag=False,
                                        attention=ones_attention))

                    validation_writer.add_summary(sum_validation, point)
                    validation_writer.flush()
                    print('end of validation---------%d' % (point))

                # ============= training batches =============
                db = ''
                patch_step = 0
                while (step * self.batch_no < self.no_sample_per_each_itr):
                    # 10-step cycle: positions 0-4 train on AMUC (with PET),
                    # positions 5-9 on LUMC (without PET). The counter also
                    # advances when a batch is incomplete and retried.
                    if patch_step >= 10:
                        patch_step = 0
                    hybrid_training_f = patch_step < 5
                    db = 'AMUC' if hybrid_training_f else 'LUMC'
                    patch_step = patch_step + 1

                    [train_asl_slices, train_pet_slices,
                     train_t1_slices] = _image_class_tr.return_patches(
                         self.batch_no, hybrid_training_f)

                    if (len(train_asl_slices) < self.batch_no
                            or len(train_pet_slices) < self.batch_no
                            or len(train_t1_slices) < self.batch_no):
                        # Not enough patches queued yet: wake the reader and retry.
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    tic = time.time()

                    if hybrid_training_f:  # AMUC: ASL + PET loss
                        cost_op, train_op = cost_withpet, optimizer_withpet
                    else:  # LUMC: ASL loss only
                        cost_op, train_op = cost_withoutpet, optimizer_withoutpet

                    # Step 1: neutral attention; also fetch the attention map.
                    [loss_train1, _, att_map1] = sess.run(
                        [cost_op, train_op, att_map],
                        feed_dict=_feed(train_asl_slices,
                                        train_t1_slices,
                                        train_pet_slices,
                                        training=True,
                                        loss_value=-1,
                                        hybrid_flag=hybrid_training_f,
                                        attention=ones_attention))

                    # Step 2: feed back the padded attention map from step 1.
                    [loss_train1, _] = sess.run(
                        [cost_op, train_op],
                        feed_dict=_feed(train_asl_slices,
                                        train_t1_slices,
                                        train_pet_slices,
                                        training=True,
                                        loss_value=-1,
                                        hybrid_flag=hybrid_training_f,
                                        attention=_pad_attention(att_map1)))
                    elapsed = time.time() - tic

                    # Training summaries; the batch loss is fed into
                    # `ave_loss_vali` so that scalar tracks training loss here.
                    [sum_train] = sess.run(
                        [summ],
                        feed_dict=_feed(train_asl_slices,
                                        train_t1_slices,
                                        train_pet_slices,
                                        training=False,
                                        loss_value=loss_train1,
                                        hybrid_flag=hybrid_training_f,
                                        attention=ones_attention))
                    train_writer.add_summary(sum_train, point)
                    train_writer.flush()
                    step = step + 1

                    print(
                        'point: %d, elapsed_time:%d step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s, %s'
                        % (int((point)), elapsed, step * self.batch_no,
                           self.learning_rate, loss_train1,
                           str(process.memory_percent()), db))

                    if point % 100 == 0:
                        # Intra-epoch checkpoint every 100 training steps.
                        chckpnt_path = os.path.join(
                            self.chckpnt_dir,
                            ('densenet_unet_inter_epoch%d_point%d.ckpt' %
                             (itr, point)))
                        saver.save(sess, chckpnt_path, global_step=point)

                    itr1 = itr1 + 1
                    point = point + 1

            endTime = time.time()

            # ============== end of epoch: checkpoint ==============
            chckpnt_path = os.path.join(self.chckpnt_dir, 'densenet_unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=itr)
            print("End of epoch----> %d, elapsed time: %d" %
                  (settings.epochs_no, endTime - startTime))
def test_all_nets(out_dir, Log, which_data):
    """Restore the trained multi-stage DenseNet and evaluate it slice by slice.

    Builds the inference graph, restores the latest checkpoint from
    ``parent_path + Log + 'unet_checkpoints/'``, runs every slice of every scan
    of the selected split through the network, saves T1 / ASL / PET input PNGs
    and the predicted ASL / PET PNGs per slice, and writes the per-slice PET
    SSIM values to ``all_ssim.xlsx`` (Sheet1: all slices, Sheet2: split into
    'HC' and 'NC' groups by the 'HY' / 'HN' markers in the slice name).

    Args:
        out_dir: output sub-directory relative to ``parent_path + Log``. It is
            string-concatenated with scan folder names, so it presumably has
            to end with '/' — confirm with callers.
        Log: run directory relative to ``parent_path``; also concatenated
            directly (presumably ends with '/').
        which_data: 1 = validation split, 2 = test split, 3 = train split.

    Raises:
        ValueError: if ``which_data`` is not 1, 2 or 3.
        FileNotFoundError: if no checkpoint is found in the checkpoint dir.
    """
    # Fail fast before any I/O; the old if/elif chain left `data` unbound.
    if which_data not in (1, 2, 3):
        raise ValueError(
            'which_data must be 1 (validation), 2 (test) or 3 (train), got %r'
            % (which_data,))

    data_path_AMUC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/AMUC_high_res/"
    data_path_LUMC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/LUMC_high_res/"

    _rd = _read_data(data_path_AMUC, data_path_LUMC)
    train_data, validation_data, test_data = _rd.read_data_path(0)
    data = {1: validation_data, 2: test_data, 3: train_data}[which_data]

    asl_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    t1_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    pet_plchld = tf.placeholder(tf.float32, shape=[None, pet_size, pet_size, 1])
    asl_out_plchld = tf.placeholder(tf.float32, shape=[None, pet_size, pet_size, 1])
    hybrid_training_flag = tf.placeholder(tf.bool, name='hybrid_training_flag')
    ave_loss_vali = tf.placeholder(tf.float32)

    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

    msdensnet = multi_stage_densenet()
    asl_y1, asl_y2, pet_y = msdensnet.multi_stage_densenet(asl_img=asl_plchld,
                                                           t1_img=t1_plchld,
                                                           pet_img=pet_plchld,
                                                           hybrid_training_flag=hybrid_training_flag,
                                                           input_dim=asl_size,
                                                           is_training=is_training,
                                                           config=config,
                                                           )
    # Weight of the SSIM term versus the Huber term in the combined losses.
    alpha = .84
    with tf.name_scope('cost'):
        ssim_asl1 = tf.reduce_mean(1 - SSIM(x1=asl_out_plchld, x2=asl_y1, max_val=34.0)[0])
        loss_asl1 = alpha * ssim_asl1 + (1 - alpha) * tf.reduce_mean(huber(labels=asl_out_plchld, logit=asl_y1))

        ssim_asl = tf.reduce_mean(1 - SSIM(x1=asl_out_plchld, x2=asl_y2, max_val=34.0)[0])
        loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(huber(labels=asl_out_plchld, logit=asl_y2))

        # Only ssim_pet is evaluated below; 1 - ssim_pet is the reported SSIM.
        ssim_pet = tf.reduce_mean(1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
        loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(huber(labels=pet_plchld, logit=pet_y))

    saver = tf.train.Saver()
    parent_path = '/exports/lkeb-hpc/syousefi/Code/'
    out_root = parent_path + Log + out_dir
    chckpnt_dir = parent_path + Log + 'unet_checkpoints/'
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in %s' % chckpnt_dir)

    list_ssim = []
    list_name = []
    list_ssim_NC = []
    list_ssim_HC = []

    # Context manager guarantees the session is closed even on failure.
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Keep a copy of the evaluation script next to its outputs.
        copyfile('./test_semisupervised_multitast_hr_residual_asl.py',
                 out_root + 'test_semisupervised_multitast_hr_residual_asl.py')

        _image_class = image_class(data,
                                   bunch_of_images_no=1,
                                   is_training=0,
                                   inp_size=asl_size,
                                   out_size=pet_size)

        for scan in data:
            ss = str(scan['asl']).split("/")
            scan_dir = ss[-1].split(".")[0].split("ASL_")[1]
            scan_out = out_root + ss[-3] + '/' + scan_dir
            imm = _image_class.read_image(scan)
            # makedirs creates the subject and scan folders independently; the
            # old bare try/except skipped the scan folder whenever the subject
            # folder already existed (second scan of the same subject).
            os.makedirs(scan_out, exist_ok=True)

            # Crop bounds are the same for every slice of this scan; all crops
            # are centred on the middle of imm[3]'s last axis.
            center = int(imm[3].shape[-1] / 2)
            asl_lo = center - int(asl_size / 2) - 1
            asl_hi = center + int(asl_size / 2)
            pet_lo = center - int(pet_size / 2) - 1
            pet_hi = center + int(pet_size / 2)
            # Central pet_size window inside the asl_size ASL crop, fed as the
            # ASL reconstruction target.
            inner_lo = int(asl_size / 2) - int(pet_size / 2) - 1
            inner_hi = int(asl_size / 2) + int(pet_size / 2)

            for img_indx in range(np.shape(imm[3])[0]):
                print('img_indx:%s' % (img_indx))

                name = ss[-3] + '_' + ss[-2] + '_' + str(img_indx)
                t1 = imm[3][img_indx, asl_lo:asl_hi, asl_lo:asl_hi]
                t1 = t1[np.newaxis, ..., np.newaxis]
                asl = imm[4][np.newaxis, img_indx, asl_lo:asl_hi, asl_lo:asl_hi, np.newaxis]
                pet = imm[5][np.newaxis, img_indx, pet_lo:pet_hi, pet_lo:pet_hi, np.newaxis]

                [loss, pet_out, asl_out1, asl_out2] = sess.run(
                    [ssim_pet, pet_y, asl_y1, asl_y2],
                    feed_dict={asl_plchld: asl,
                               t1_plchld: t1,
                               pet_plchld: pet,
                               asl_out_plchld: asl[:, inner_lo:inner_hi, inner_lo:inner_hi, :],
                               is_training: False,
                               ave_loss_vali: -1,
                               is_training_bn: False,
                               hybrid_training_flag: False
                               })

                ssim = 1 - loss
                list_ssim.append(ssim)
                str_nm = ss[-3] + '_' + scan_dir + '_t1_' + name
                if 'HN' in str_nm:
                    list_ssim_NC.append(ssim)
                elif 'HY' in str_nm:
                    list_ssim_HC.append(ssim)
                list_name.append(str_nm)
                # list_name spans all scans, so indexing it by img_indx printed
                # the wrong name for every scan after the first.
                print(str_nm, ': ', ssim)

                prefix = scan_out + '/'
                matplotlib.image.imsave(prefix + 't1_' + name + '.png', np.squeeze(t1), cmap='gray')
                matplotlib.image.imsave(prefix + 'asl_' + name + '_' + '.png', np.squeeze(asl), cmap='gray')
                matplotlib.image.imsave(prefix + 'pet_' + name + '_' + '.png', np.squeeze(pet), cmap='gray')
                matplotlib.image.imsave(prefix + 'res_asl' + name + '_' + str(ssim) + '.png',
                                        np.squeeze(asl_out2), cmap='gray')
                matplotlib.image.imsave(prefix + 'res_pet' + name + '_' + str(ssim) + '.png',
                                        np.squeeze(pet_out), cmap='gray')

    # Sheet1: one row per slice, labelled with the slice name.
    df = pd.DataFrame(list_ssim,
                      index=list_name,
                      columns=pd.Index(['ssim'], name='Genus')).round(2)
    # Sheet2: one column per group; transpose must be assigned, it is not in-place.
    df2 = pd.DataFrame.from_dict({'HC': list_ssim_HC, 'NC': list_ssim_NC},
                                 orient='index').transpose()

    xlsx_path = out_root + '/all_ssim.xlsx'
    # A single writer holds both sheets; two writers on the same path made the
    # second save overwrite the first.
    with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
        df2.to_excel(writer, sheet_name='Sheet2')

    print(xlsx_path)