Example 1
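# Sliding-window evaluation of a trained DenseNet-UNet: restores a checkpoint,
# tiles each test CT into overlapping cubes, stitches the per-patch softmax
# maps back into full volumes, and writes metrics, .mha volumes and plots.
# Relies on module-level settings (ct_cube_size, gtv_cube_size, gap, log_tag,
# plot_tag, test_vali, densnet_unet_config) defined outside this snippet.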
def test_all_nets(fold, out_dir, Log):
    # densnet_unet_config = [3, 4, 4, 4, 3]
    compression_coefficient = .75
    growth_rate = 4
    # pad_size = 14
    # ext = ''.join(map(str, densnet_unet_config))  # +'_'+str(compression_coefficient)+'_'+str(growth_rate)
    data = 2

    # sample_no=2280000
    # validation_samples=5700
    # no_sample_per_each_itr=3420

    # train_tag='train/'
    # validation_tag='validation/'
    # test_tag='test/'
    # img_name='CT_padded.mha'
    # label_name='GTV_CT_padded.mha'
    # torso_tag='CT_padded_Torso.mha'

    train_tag = 'train/'
    validation_tag = 'validation/'
    test_tag = 'Esophagus/'
    # img_name='CTpadded.mha'
    # label_name='GTV_CTpadded.mha'
    # torso_tag='Torsopadded.mha'

    img_name = ''
    label_name = ''
    torso_tag = ''

    _rd = _read_data(data=data,
                     train_tag=train_tag,
                     validation_tag=validation_tag,
                     test_tag=test_tag,
                     img_name=img_name,
                     label_name=label_name)
    test_path = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/' + test_tag
    chckpnt_dir = '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log + log_tag + '/densenet_unet_checkpoints/'

    # test_CTs, test_GTVs ,test_Torsos= _rd.read_imape_path(test_path)

    train_CTs, train_GTVs, train_Torso, train_penalize, \
    validation_CTs, validation_GTVs, validation_Torso, validation_penalize, \
    test_CTs, test_GTVs, test_Torso, test_penalize = _rd.read_data_path(fold=fold)

    # test_CTs=train_CTs
    # test_GTVs=train_GTVs
    # test_Torso=train_Torso
    # test_penalize=train_penalize

    # test_CTs=np.sort(test_CTs)
    # test_GTVs=np.sort(test_GTVs)
    # test_Torso=np.sort(test_Torso)
    # test_penalize=np.sort(test_penalize)
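    # test_vali is assumed to be a module-level flag: 1 evaluates the
    # validation split, anything else the held-out test split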
    if test_vali == 1:
        test_CTs = np.sort(validation_CTs)
        test_GTVs = np.sort(validation_GTVs)
        test_Torso = np.sort(validation_Torso)
        test_penalize = np.sort(validation_penalize)
    else:
        test_CTs = np.sort(test_CTs)
        test_GTVs = np.sort(test_GTVs)
        test_Torso = np.sort(test_Torso)
        test_penalize = np.sort(test_penalize)

    lf = _loss_func()
    learning_rate = 1E-4
    # image = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
    # label = tf.placeholder(tf.float32, shape=[None, None, None, None, 2])

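    # 5-D placeholders [batch, z, y, x, channels]: the network consumes
    # ct_cube_size**3 CT patches and predicts gtv_cube_size**3 two-class maps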
    image = tf.placeholder(
        tf.float32, shape=[None, ct_cube_size, ct_cube_size, ct_cube_size, 1])
    label = tf.placeholder(
        tf.float32,
        shape=[None, gtv_cube_size, gtv_cube_size, gtv_cube_size, 2])

    ave_vali_acc = tf.placeholder(tf.float32)
    ave_loss_vali = tf.placeholder(tf.float32)

    dropout = tf.placeholder(tf.float32, name='dropout')
    # dropout2=tf.placeholder(tf.float32,name='dropout2')
    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')
    dense_net_dim = tf.placeholder(tf.int32, name='dense_net_dim')
    penalize = tf.placeholder(tf.float32, shape=[None, None, None, None, 2])
    loss_coef = tf.placeholder(
        tf.float32, shape=[None,
                           2])  # shape: batchno * 2 values for each class

    _dn = _densenet_unet(densnet_unet_config, compression_coefficient,
                         growth_rate)  # create object
    dn_out = _dn.dens_net(image,
                          is_training,
                          dropout_rate1=0,
                          dropout_rate2=0,
                          dim=ct_cube_size,
                          is_training_bn=is_training_bn)
    y = tf.nn.softmax(dn_out)
    log_y = tf.nn.log_softmax(dn_out)

    # y=_dn.vgg(image)
    loss_instance = _loss_func()

    accuracy = loss_instance.accuracy_fn(y, label)
    [dice, edited_dice] = loss_instance.penalize_dice(logits=y,
                                                      labels=label,
                                                      penalize=penalize)
    # soft_dice_coef=self.loss_instance.soft_dice(logits=y, labels=label)
    cost = tf.reduce_mean(1.0 - dice[1], name="cost")
    # correct_prediction = tf.equal(tf.argmax(y, 4), tf.argmax(label, 4))
    # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # cost = tf.reduce_mean(lf.tversky(logits=y, labels=label, alpha=0.9, beta=0.1), name="cost")

    # restore the model
    sess = tf.Session()
    saver = tf.train.Saver()

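    # get_checkpoint_state returns None if chckpnt_dir holds no checkpoint;
    # the latest checkpoint is assumed to exist here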
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)
    _meas = _measure()
    out_path = chckpnt_dir + 'output/'
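    # archive the exact test script next to the outputs for reproducibility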
    copyfile(
        './test_densenet_unet.py',
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + 'test_densenet_unet.py')

    jj = []
    dd = []
    dice_boxplot0 = []
    dice_boxplot1 = []
    dice_boxplot = []

    jacc_boxplot = []
    jacc_boxplot0 = []
    jacc_boxplot1 = []

    f1_boxplot0 = []
    f1_boxplot1 = []
    f1_boxplot_av = []

    fpr_av = []
    fnr_av = []
    xtickNames = []
    name_list = []

    fpr0 = []
    fpr1 = []

    fnr0 = []
    fnr1 = []

    sp0 = []
    sp1 = []

    recall0 = []
    recall1 = []
    recall_av = []
    precision0 = []
    precision1 = []
    precision_av = []
    img_class = image_class(test_CTs,
                            test_GTVs,
                            test_Torso,
                            test_penalize,
                            bunch_of_images_no=20,
                            is_training=1,
                            patch_window=ct_cube_size,
                            gtv_patch_window=gtv_cube_size)

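    # note: the loop starts at 1, so the volume at index 0 is skipped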
    for img_indx in range(1, len(test_CTs)):
        print('img_indx:%s' % (img_indx))
        ss = str(test_CTs[img_indx]).split("/")
        name = (ss[8] + '_' + ss[9] + '_' +
                ss[10].split('%')[0]).split('_CT')[0]
        [
            CT_image, GTV_image, Torso_image, volume_depth, voxel_size, origin,
            direction
        ] = _rd.read_image_seg_volume(test_CTs, test_GTVs, test_Torso,
                                      img_indx, ct_cube_size, gtv_cube_size)

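        # slide a cubic window over the volume along z, x, y with stride
        # ct_cube_size - gap + 1 (gap is a module-level overlap setting) and
        # stitch the per-patch outputs back together along axes 0, 1, 2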
        _zz_img_gt = []
        for _z in (range(
                int(ct_cube_size / 2) + 1,
                CT_image.shape[0] - int(ct_cube_size / 2) + 7,
                int(ct_cube_size) - int(gap) + 1)):
            _xx_img_gt = []
            for _x in (range(
                    int(ct_cube_size / 2) + 1,
                    CT_image.shape[1] - int(ct_cube_size / 2) + 7,
                    int(ct_cube_size) - int(gap) + 1)):
                _yy_img_gt = []
                for _y in (range(
                        int(ct_cube_size / 2) + 1,
                        CT_image.shape[2] - int(ct_cube_size / 2) + 7,
                        int(ct_cube_size) - int(gap) + 1)):

                    ct = CT_image[_z - int(ct_cube_size / 2) - 1:_z +
                                  int(ct_cube_size / 2),
                                  _x - int(ct_cube_size / 2) - 1:_x +
                                  int(ct_cube_size / 2),
                                  _y - int(ct_cube_size / 2) - 1:_y +
                                  int(ct_cube_size / 2)]
                    ct = ct[np.newaxis][..., np.newaxis]
                    gtv = GTV_image[_z - int(gtv_cube_size / 2) - 1:_z +
                                    int(gtv_cube_size / 2),
                                    _x - int(gtv_cube_size / 2) - 1:_x +
                                    int(gtv_cube_size / 2),
                                    _y - int(gtv_cube_size / 2) - 1:_y +
                                    int(gtv_cube_size / 2)]

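                    # binarize the mask (assumes a non-empty GTV, max > 0) and
                    # one-hot encode it: channel 0 = background, channel 1 = tumor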
                    gtv = np.int32(gtv / np.max(GTV_image))

                    gtv = np.eye(2)[gtv]
                    gtv = gtv[np.newaxis]

                    if len(np.where(gtv[0, :, :, :, 1] != 0)[0]):
                        print('o')  # debug marker: patch contains tumor voxels

                    [acc_vali, loss_vali, out, log_out] = sess.run(
                        [accuracy, cost, y, log_y],
                        feed_dict={
                            image: ct,
                            label: gtv,
                            # penalize: pnlz,
                            # loss_coef:loss_coef_weights,
                            dropout: 1,
                            is_training: False,
                            ave_vali_acc: -1,
                            ave_loss_vali: -1,
                            dense_net_dim: ct_cube_size,
                            is_training_bn: False,
                        })

                    if len(_yy_img_gt) == 0:
                        _yy_img_gt = np.int32(gtv[0, :, :, :, 1])
                        _yy_img = out[0, :, :, :, 1]  # keep probabilities; rounded later for the binary mask

                        _yy_img_ct = CT_image[_z - int(gtv_cube_size / 2) -
                                              1:_z + int(gtv_cube_size / 2),
                                              _x - int(gtv_cube_size / 2) -
                                              1:_x + int(gtv_cube_size / 2),
                                              _y - int(gtv_cube_size / 2) -
                                              1:_y + int(gtv_cube_size / 2)]
                    else:
                        _yy_img_gt = np.concatenate(
                            (_yy_img_gt, gtv[0, :, :, :, 1]), axis=2)
                        _yy_img = np.concatenate((_yy_img, out[0, :, :, :, 1]),
                                                 axis=2)

                        _yy_img_ct = np.concatenate(
                            (_yy_img_ct,
                             CT_image[_z - int(gtv_cube_size / 2) - 1:_z +
                                      int(gtv_cube_size / 2),
                                      _x - int(gtv_cube_size / 2) - 1:_x +
                                      int(gtv_cube_size / 2),
                                      _y - int(gtv_cube_size / 2) - 1:_y +
                                      int(gtv_cube_size / 2)]),
                            axis=2)

                if len(_xx_img_gt) == 0:
                    _xx_img_gt = _yy_img_gt
                    _xx_img = _yy_img
                    _xx_img_ct = _yy_img_ct
                else:
                    _xx_img_gt = np.concatenate((_xx_img_gt, _yy_img_gt),
                                                axis=1)
                    _xx_img = np.concatenate((_xx_img, _yy_img), axis=1)
                    _xx_img_ct = np.concatenate((_xx_img_ct, _yy_img_ct),
                                                axis=1)

            if len(_zz_img_gt) == 0:
                _zz_img_gt = _xx_img_gt
                _zz_img = _xx_img
                _zz_img_ct = _xx_img_ct
            else:
                _zz_img_gt = np.concatenate((_zz_img_gt, _xx_img_gt), axis=0)
                _zz_img = np.concatenate((_zz_img, _xx_img), axis=0)
                _zz_img_ct = np.concatenate((_zz_img_ct, _xx_img_ct), axis=0)

        name_list.append(name)

        # confusion counts on the binarized, stitched prediction
        [TP, TN, FP, FN] = tp_tn_fp_fn(np.round(_zz_img), _zz_img_gt)

        f1 = f1_measure(TP, TN, FP, FN)
        print('%s: f1_background:%f, f1_tumor:%f' % (name, f1[0], f1[1]))
        f1_boxplot0.append(f1[0])
        f1_boxplot1.append(f1[1])
        f1_boxplot_av.append((f1[0] + f1[1]) / 2)

        fpr = FPR(TP, TN, FP, FN)
        fpr0.append(fpr[0])
        fpr1.append(fpr[1])
        fpr_av.append((fpr[0] + fpr[1]) / 2)

        fnr = FNR(TP, TN, FP, FN)
        fnr0.append(fnr[0])
        fnr1.append(fnr[1])
        fnr_av.append((fnr[0] + fnr[1]) / 2)

        precision = Precision(TP, TN, FP, FN)
        presicion0.append(precision[0])
        presicion1.append(precision[1])
        presicion_av.append((precision[0] + precision[1]) / 2)

        recall = Recall(TP, TN, FP, FN)
        recall0.append(recall[0])
        recall1.append(recall[1])
        recall_av.append((recall[0] + recall[1]) / 2)

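        # write four volumes with the source geometry copied over: the binary
        # mask (_result), the probability map (_fuzzy), the ground truth
        # (_gtv) and the stitched CT (_ct)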
        _zz_img1 = np.round(_zz_img)
        segmentation = np.asarray(_zz_img1)
        predicted_label = sitk.GetImageFromArray(segmentation.astype(np.uint8))
        predicted_label.SetDirection(direction)
        predicted_label.SetOrigin(origin)
        predicted_label.SetSpacing(voxel_size)
        sitk.WriteImage(
            predicted_label,
            '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' +
            Log + out_dir + name + '_result.mha')

        segmentation = np.asarray(_zz_img)
        predicted_label = sitk.GetImageFromArray(
            segmentation.astype(np.float32))
        predicted_label.SetDirection(direction)
        predicted_label.SetOrigin(origin)
        predicted_label.SetSpacing(voxel_size)
        sitk.WriteImage(
            predicted_label,
            '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' +
            Log + out_dir + name + '_fuzzy.mha')

        segmentation = np.asarray(_zz_img_gt)
        predicted_label = sitk.GetImageFromArray(segmentation.astype(np.uint8))
        predicted_label.SetDirection(direction)
        predicted_label.SetOrigin(origin)
        predicted_label.SetSpacing(voxel_size)
        sitk.WriteImage(
            predicted_label,
            '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' +
            Log + out_dir + name + '_gtv.mha')

        segmentation = np.asarray(_zz_img_ct)
        predicted_label = sitk.GetImageFromArray(segmentation.astype(np.short))
        predicted_label.SetDirection(direction)
        predicted_label.SetOrigin(origin)
        predicted_label.SetSpacing(voxel_size)
        sitk.WriteImage(
            predicted_label,
            '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' +
            Log + out_dir + name + '_ct.mha')
        # output(filename, sheet, list1, list2, x, y, z)
        # output('/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/'+Log+out_dir +'file.xls','sheet 1',fnr,fpr,'a','b','c')
        print('end')

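    # aggregate plots over all evaluated volumes; index 0 is the background
    # class and index 1 the tumor class, matching the one-hot channels above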
    f1_bp0 = []
    f1_bp1 = []
    f1_bp_av = []
    f1_bp0.append(f1_boxplot0)
    f1_bp1.append(f1_boxplot1)
    f1_bp_av.append(f1_boxplot_av)
    plt.figure()
    plt.boxplot(f1_bp0, 0, '')
    plt.title('Background Dice value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'f1_bp_background.png')

    plt.figure()
    plt.boxplot(f1_bp1, 0, '')
    plt.title('Tumor Dice value for all the images ' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'f1_bp_tumor.png')

    plt.figure()
    plt.boxplot(f1_bp_av, 0, '')
    plt.title('Average Dice value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'f1_bp_average.png')
    #----------------------
    fpr_bp0 = []
    fpr_bp0.append(fpr0)
    plt.figure()
    plt.boxplot(fpr_bp0, 0, '')
    plt.title('FPR Background value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fpr_bp_background.png')

    fpr_bp1 = []
    fpr_bp1.append(fpr1)
    plt.figure()
    plt.boxplot(fpr_bp1, 0, '')
    plt.title('FPR Tumor value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fpr_bp_tumor.png')

    fpr_bp = []
    fpr_bp.append(fpr_av)
    plt.figure()
    plt.boxplot(fpr_bp, 0, '')
    plt.title('FPR Average value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fpr_bp_average.png')

    #----------------------
    fnr_bp0 = []
    fnr_bp0.append(fnr0)
    plt.figure()
    plt.boxplot(fnr_bp0, 0, '')
    plt.title('FNR Background value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fnr_bp_background.png')

    fnr_bp1 = []
    fnr_bp1.append(fnr1)
    plt.figure()
    plt.boxplot(fnr_bp1, 0, '')
    plt.title('FNR Tumor value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fnr_bp_tumor.png')

    fnr_bp = []
    fnr_bp.append(fnr_av)
    plt.figure()
    plt.boxplot(fnr_bp, 0, '')
    plt.title('FNR Average value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fnr_bp_average.png')
    #----------------------
    pres_bp0 = []
    pres_bp0.append(precision0)
    plt.figure()
    plt.boxplot(pres_bp0, 0, '')
    plt.title('Precision Background value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'precision_bp_background.png')

    pres_bp1 = []
    pres_bp1.append(precision1)
    plt.figure()
    plt.boxplot(pres_bp1, 0, '')
    plt.title('Precision Tumor value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'precision_bp_tumor.png')

    pres_bp = []
    pres_bp.append(precision_av)
    plt.figure()
    plt.boxplot(pres_bp, 0, '')
    plt.title('Precision Average value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'precision_bp_average.png')
    #----------------------
    recall_bp0 = []
    recall_bp0.append(recall0)
    plt.figure()
    plt.boxplot(recall_bp0, 0, '')
    plt.title('Recall Background value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_bp_background.png')

    recall_bp1 = []
    recall_bp1.append(recall1)
    plt.figure()
    plt.boxplot(recall_bp1, 0, '')
    plt.title('Recall Tumor value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_bp_tumor.png')

    recall_bp = []
    recall_bp.append(recall_av)
    plt.figure()
    plt.boxplot(recall_bp, 0, '')
    plt.title('Recall Average value for all the images' + plot_tag)
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_bp_average.png')
    #----------------------
    plt.figure()
    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, f1_boxplot0, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.45)
    plt.title('Background Dice all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'dice_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, fnr0, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('FNR Background all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fnr_background_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, fnr1, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('FNR Tumor all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fnr_tumor_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, fpr0, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('FPR Background all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fpr_background_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, fpr1, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('FPR Tumor all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'fpr_tumor_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, recall0, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('Recall Background all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_background_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, recall1, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('Recall Tumor all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_tumor_bar.png')

    #----------------------
    plt.figure()

    xs = [i for i, _ in enumerate(name_list)]

    plt.bar(xs, recall_av, align='center')
    plt.xticks(xs, name_list, rotation='vertical')
    plt.margins(.05)
    plt.subplots_adjust(bottom=0.25)
    plt.title('Recall Average all images' + plot_tag)
    plt.grid()
    plt.savefig(
        '/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/' + Log +
        out_dir + name + 'recall_average_bar.png')
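
# A minimal driver sketch (hypothetical fold and output directory; `Log` and
# `out_dir` follow the path conventions used above):
#
#     test_all_nets(fold=0, out_dir='eval_fold0/', Log='Log_2018-08-15/')

Example 2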
    img_row1 = tf.placeholder(
        tf.float32,
        shape=[batch_no, patch_window, patch_window, patch_window, 1])
    label1 = tf.placeholder(tf.float32,
                            shape=[
                                batch_no, label_patchs_size, label_patchs_size,
                                label_patchs_size, 1
                            ])
    # img_row1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
    # label1 = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
    input_dim = tf.placeholder(tf.int32, name='unet_input_dim')
    is_training = tf.placeholder(tf.bool, name='unet_is_training')
    unet = _unet(trainable=False, file_name=ckpoint_path)
    y = unet.unet(img_row1=img_row1,
                  input_dim=input_dim,
                  is_training=is_training)
    _rd = _read_data(data=data,
                     img_name=img_name,
                     label_name=label_name,
                     dataset_path=data_path)

    train_data, validation_data, test_data = _rd.read_data_path(fold=fold)
    input_cube_size = 47
    gt_cube_size = input_cube_size
    test_vali = 0
    if test_vali == 1:
        test_set = validation_data
    else:
        test_set = test_data
    img_class = image_class(test_set,
                            bunch_of_images_no=1,
                            is_training=1,
                            patch_window=input_cube_size,
                            sample_no_per_bunch=1)
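
Example 3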
    def run_net(self):
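        '''Build the DenseNet-UNet training graph and run the epoch/validation loops.'''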


        '''read 2d images from the data:'''
        two_dim=True
        resampled_path = '/exports/lkeb-hpc/syousefi/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Data-01/21data1100data2-v3/'
        _rd = _read_data(data=self.data,train_tag=self.train_tag, validation_tag=self.validation_tag, test_tag=self.test_tag,
                         img_name=self.img_name, label_name=self.label_name,torso_tag=self.torso_tag,resampled_path=resampled_path)
        # _rd = _read_data(train_tag='prostate_train/',validation_tag='prostate_validation/',test_tag='prostate_test/',
        #          img_name='.mha',label_name='Bladder.mha',torso_tag=self.torso_tag)
        # _rd = _read_data(train_tag='train/', validation_tag='validation/', test_tag='test/',
        #                  img_name='CT.mha', label_name='GTV_CT.mha',torso_tag=self.torso_tag)

        flag=False
        self.alpha_coeff=1



        '''read path of the images for train, test, and validation'''
        train_CTs, train_GTVs, train_Torso, train_penalize, \
        validation_CTs, validation_GTVs, validation_Torso, validation_penalize, \
        test_CTs, test_GTVs, test_Torso, test_penalize=_rd.read_data_path(fold=self.fold)
        self.img_width = 500
        self.img_height = 500
        # ======================================
        bunch_of_images_no=20
        sample_no=500
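        # validation pipeline: an image_class buffer filled by a fill thread,
        # sampled by a patch-extractor thread and drained by a read thread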
        _image_class_vl = image_class(validation_CTs, validation_GTVs, validation_Torso,validation_penalize
                                      , bunch_of_images_no=bunch_of_images_no,  is_training=0,
                                      patch_window=self.patch_window)
        _patch_extractor_thread_vl = _patch_extractor_thread(_image_class=_image_class_vl,
                                                             sample_no=sample_no, patch_window=self.patch_window,
                                                             GTV_patchs_size=self.GTV_patchs_size,
                                                             tumor_percent=self.tumor_percent,
                                                             other_percent=self.other_percent,
                                                             img_no=bunch_of_images_no,
                                                             mutex=settings.mutex,is_training=0,vl_sample_no=self.validation_samples
                                                             )
        _fill_thread_vl = fill_thread(validation_CTs,
                                      validation_GTVs,
                                      validation_Torso,
                                      validation_penalize,
                                      _image_class_vl,
                                      sample_no=sample_no,
                                      total_sample_no=self.validation_samples,
                                      patch_window=self.patch_window,
                                      GTV_patchs_size=self.GTV_patchs_size,
                                      img_width=self.img_width, img_height=self.img_height,
                                      mutex=settings.mutex,
                                      tumor_percent=self.tumor_percent,
                                      other_percent=self.other_percent,
                                      is_training=0,
                                      patch_extractor=_patch_extractor_thread_vl,
                                      fold=self.fold)


        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        #time.sleep(1)
        _read_thread_vl = read_thread(_fill_thread_vl, mutex=settings.mutex,
                                      validation_sample_no=self.validation_samples, is_training=0)
        _read_thread_vl.start()
        # ======================================
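        # training pipeline: the same thread trio as validation, now in
        # training mode (is_training=1)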
        bunch_of_images_no = 24
        sample_no=240
        _image_class = image_class(train_CTs, train_GTVs, train_Torso,train_penalize
                                   , bunch_of_images_no=bunch_of_images_no,is_training=1,patch_window=self.patch_window
                                   )
        patch_extractor_thread = _patch_extractor_thread(_image_class=_image_class,
                                                         sample_no=sample_no, patch_window=self.patch_window,
                                                         GTV_patchs_size=self.GTV_patchs_size,
                                                         tumor_percent=self.tumor_percent,
                                                         other_percent=self.other_percent,
                                                         img_no=bunch_of_images_no,
                                                         mutex=settings.mutex,is_training=1)
        _fill_thread = fill_thread(train_CTs, train_GTVs, train_Torso,train_penalize,
                                   _image_class,
                                   sample_no=sample_no,total_sample_no=self.sample_no,
                                   patch_window=self.patch_window,
                                   GTV_patchs_size=self.GTV_patchs_size,
                                   img_width=self.img_width,
                                   img_height=self.img_height,mutex=settings.mutex,
                                   tumor_percent=self.tumor_percent,
                                   other_percent=self.other_percent,is_training=1,
                                   patch_extractor=patch_extractor_thread,
                                   fold=self.fold)

        _fill_thread.start()
        patch_extractor_thread.start()

        _read_thread = read_thread(_fill_thread,mutex=settings.mutex,is_training=1)
        _read_thread.start()
        # ======================================
        # pre_bn=tf.placeholder(tf.float32,shape=[None,None,None,None,None])
        # image=tf.placeholder(tf.float32,shape=[self.batch_no,self.patch_window,self.patch_window,self.patch_window,1])
        # label=tf.placeholder(tf.float32,shape=[self.batch_no_validation,self.GTV_patchs_size,self.GTV_patchs_size,self.GTV_patchs_size,2])
        # loss_coef=tf.placeholder(tf.float32,shape=[self.batch_no_validation,1,1,1])

        image = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
        label = tf.placeholder(tf.float32, shape=[None, None, None, None, 2])
        # pnalize = tf.placeholder(tf.float32, shape=[None, None, None, None,2])
        # loss_coef = tf.placeholder(tf.float32, shape=[None, 2]) # shape: batchno * 2 values for each class
        alpha = tf.placeholder(tf.float32, name='alpha') # background coeff
        beta = tf.placeholder(tf.float32, name='beta') # tumor coeff



        ave_vali_acc=tf.placeholder(tf.float32)
        ave_loss_vali=tf.placeholder(tf.float32)
        ave_dsc_vali=tf.placeholder(tf.float32)

        dropout=tf.placeholder(tf.float32,name='dropout')
        is_training = tf.placeholder(tf.bool, name='is_training')
        is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')
        dense_net_dim = tf.placeholder(tf.int32, name='dense_net_dim')

        # _u_net=_unet()
        # _u_net.unet(image)
        _dn = _densenet_unet(self.densnet_unet_config,self.compression_coefficient,self.growth_rate) #create object
        y=_dn.dens_net(image=image,is_training=is_training,dropout_rate1=0,dropout_rate2=0,dim=dense_net_dim,is_training_bn=is_training_bn)
        # y = _dn.vgg(image)

        y_dirX = y[:, int(self.GTV_patchs_size / 2), :, :, 0, np.newaxis]
        label_dirX = label[:, int(self.GTV_patchs_size / 2), :, :, 0, np.newaxis]
        # pnalize_dirX = (pnalize[:, 16, :, :, 0, np.newaxis])
        image_dirX = image[:, int(self.patch_window / 2), :, :, 0, np.newaxis]
        # x_Fixed = label[0,np.newaxis,:,:,0,np.newaxis]#tf.expand_dims(tf.expand_dims(y[0,10, :, :, 1], 0), -1)
        # x_Deformed = tf.expand_dims(tf.expand_dims(y[0,10, :, :, 1], 0), -1)


        show_img = tf.nn.softmax(y)[:, int(self.GTV_patchs_size / 2), :, :, 0, np.newaxis]
        tf.summary.image('output_with_softmax', show_img, 3)
        tf.summary.image('output_without_softmax', y_dirX, 3)
        tf.summary.image('groundtruth', label_dirX, 3)
        # tf.summary.image('pnalize', pnalize_dirX, 3)
        tf.summary.image('image', image_dirX, 3)
        sess = tf.Session()
        log_extttt = ''  # self.log_ext.split('_')[0]+'01'
        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train' + log_extttt,graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation' + log_extttt, graph=sess.graph)
        # y=_dn.vgg(image)
        saver=tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        # copyfile('./functions/densenet_unet.py', self.LOGDIR + 'densenet_unet.py')
        copyfile('./functions/densenet_classify2.py', self.LOGDIR + 'densenet_classify2.py')
        copyfile('./functions/image_class.py', self.LOGDIR + 'image_class.py')
        copyfile('./functions/read_data.py', self.LOGDIR + 'read_data.py')
        copyfile('./run_allnets.py', self.LOGDIR + 'run_allnets.py')
        copyfile('./functions/fill_thread.py', self.LOGDIR + 'fill_thread.py')
        copyfile('./functions/read_thread.py', self.LOGDIR + 'read_thread.py')
        copyfile('./functions/loss_func.py', self.LOGDIR + 'loss_func.py')


        '''AdamOptimizer:'''
        with tf.name_scope('cost'):
            # cost = 1-dsc_fn(y, label, 1E-4)
            # cost = -tf.reduce_sum(label*tf.log(tf.clip_by_value(y,1e-10,1.0)))
            # cost = tf.reduce_mean(tf.losses.softmax_cross_entropy(logits=y, onehot_labels=label), name="cost")#
            # cost = tf.reduce_mean(tversky(logits=y, labels=label, alpha=self.alpha, beta=self.beta), name="cost")
            # [dice,edited_dice]=self.loss_instance.penalize_dice(logits=y, labels=label,penalize=[])
            # soft_dice_coef=self.loss_instance.soft_dice(logits=y, labels=label)

            [soft_dice_coef,logt,lbl]=self.loss_instance.soft_dice(logits=y, labels=label)
            cost = tf.reduce_mean(1.0 - soft_dice_coef[1], name="cost")

            # wce=self.loss_instance.weighted_cross_entrophy_loss( logits=y, labels=label)
            # cost = tf.reduce_mean(wce, name="cost")


            # cost_before = tf.reduce_mean(1.0 -edited_dice[1], name="cost_before")
            # cost = tf.reduce_mean(self.loss_instance.f1_measure(logits=y, labels=label), name="cost")
            # [f1_score,weighted_f1_score]=self.loss_instance.f1_measure(logits=y, labels=label,
            #                                        alpha=alpha,beta=beta)
            # cost = tf.reduce_mean(1.0 -weighted_f1_score, name="cost")

        tf.summary.scalar("cost", cost)
        # tf.summary.scalar("cost_before", cost_before)
        # tf.summary.scalar("recall_precision", recall_precision)
        f1_measure = self.loss_instance.f1_measure(logits=y, labels=label)
        tf.summary.scalar("dice_background", f1_measure[0])
        tf.summary.scalar("dice_tumor", f1_measure[1])

        pwc = self.loss_instance.PWC(y, label)
        tf.summary.scalar("pwc_background", pwc[0])
        tf.summary.scalar("pwc_tumor", pwc[1])

        recall = self.loss_instance.Recall(y, label)
        tf.summary.scalar("recall_background", recall[0])
        tf.summary.scalar("recall_tumor", recall[1])

        precision = self.loss_instance.Precision(y, label)
        tf.summary.scalar("precision_background", precision[0])
        tf.summary.scalar("precision_tumor", precision[1])

        fpr = self.loss_instance.FPR(y, label)
        tf.summary.scalar("FPR_background", fpr[0])
        tf.summary.scalar("FPR_tumor", fpr[1])

        fnr = self.loss_instance.FNR(y, label)
        tf.summary.scalar("FNR_background", fnr[0])
        tf.summary.scalar("FNR_tumor", fnr[1])

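        # batch-norm moving-average updates live in UPDATE_OPS; wrapping the
        # optimizer in a control dependency keeps them running every step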
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            # optimizer_tmp = tf.train.AdamOptimizer(self.learning_rate)
            optimizer_tmp = RAdam(learning_rate=self.learning_rate)
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer_tmp)
            optimizer = optimizer.minimize(cost)

        with tf.name_scope('validation'):
            average_validation_accuracy=ave_vali_acc
            average_validation_loss=ave_loss_vali
            average_dsc_loss=ave_dsc_vali
        tf.summary.scalar("average_validation_accuracy",average_validation_accuracy)
        tf.summary.scalar("average_validation_loss",average_validation_loss)
        tf.summary.scalar("average_dsc_loss",average_dsc_loss)

        with tf.name_scope('accuracy'):
            # accuracy=dsc_fn(y, label,1E-4)
            accuracy=self.loss_instance.accuracy_fn(y, label)
            # f1_score=self.loss_instance.f1_measure(y, label)
            # accuracy=tf.reduce_mean(f1_score)
        tf.summary.scalar("accuracy", accuracy)
        # tf.summary.scalar("f1_score1",f1_score[0])
        # tf.summary.scalar("f1_score2",f1_score[1])


        sess.run(tf.global_variables_initializer())
        # train_writer.add_graph(sess.graph)
        logging.debug('total number of variables %s' % (
            np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))

        summ=tf.summary.merge_all()
        loadModel = 0
        point = 0
        itr1 = 0
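        # set loadModel = 1 to resume training: `point` (the global step) is
        # parsed from the checkpoint filename suffix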
        if loadModel:
            chckpnt_dir='/srv/2-lkeb-17-dl01/syousefi/TestCode/EsophagusProject/Code/Log_2018-08-15/dilated-mid/22322_1_4-without-firstlayers-01/densenet_unet_checkpoints/'
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            point=np.int16(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            itr1=point


        # patch_radius = 49
        '''loop for epochs'''

        for epoch in range(self.total_epochs):
            # consume the epoch's sample budget (self.sample_no) in chunks of
            # self.no_sample_per_each_itr
            while self.no_sample_per_each_itr * int(point / self.no_sample_per_each_itr) < self.sample_no:
                # self.save_file(self.train_acc_file, 'epoch: %d\n' % (epoch))
                # self.save_file(self.validation_acc_file, 'epoch: %d\n' % (epoch))
                print("epoch #: %d" %(epoch))
                startTime = time.time()

                step = 0

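                # tumor-class loss weight: starts at 2 and decays
                # exponentially toward 1 as training progresses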
                self.beta_coeff=1+1 * np.exp(-point/2000)

                # if self.beta_coeff>1:
                #     self.beta_coeff=self.beta_coeff-self.beta_rate*self.beta_coeff
                #     if self.beta_coeff<1:
                #         self.beta_coeff=1

                # =============validation================
                if itr1 % self.display_validation_step ==0:
                    '''Validation: '''
                    loss_validation = 0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation=0

                    while (validation_step * self.batch_no_validation <settings.validation_totalimg_patch):


                        [validation_CT_image, validation_GTV_image] =_image_class_vl.return_patches_validation( validation_step * self.batch_no_validation,
                                                                                                                (validation_step + 1) *self.batch_no_validation)
                        if (len(validation_CT_image)<self.batch_no_validation) | (len(validation_GTV_image)<self.batch_no_validation) :
                            _read_thread_vl.resume()
                            time.sleep(0.5)
                            # print('sleep 3 validation')
                            continue

                        validation_CT_image_patchs = validation_CT_image
                        validation_GTV_label = validation_GTV_image



                        [acc_vali, loss_vali,dsc_vali] = sess.run([accuracy, cost,f1_measure],
                                                         feed_dict={image: validation_CT_image_patchs,
                                                                    label: validation_GTV_label,
                                                                    dropout: 1,
                                                                    is_training: False,
                                                                    ave_vali_acc: -1,
                                                                    ave_loss_vali: -1,
                                                                    ave_dsc_vali:-1,
                                                                    dense_net_dim: self.patch_window,
                                                                    is_training_bn:False,
                                                                    alpha:1,
                                                                    beta:1})



                        acc_validation += acc_vali
                        loss_validation += loss_vali
                        dsc_validation+=dsc_vali[1]
                        validation_step += 1
                        if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                            print('nan problem')
                        process = psutil.Process(os.getpid())

                        # print('%d - > %d:  acc_validation: %f, loss_validation: %f, no_list: %d, memory_percent: %s, memory_info: %s' % (validation_step,validation_step * self.batch_no_validation
                        #                                                                         ,acc_vali, loss_vali,len(settings.bunch_GTV_patches_vl2),str(process.memory_percent()),str(process.memory_info())))
                        print(
                            '%d - > %d:  acc_validation: %f, loss_validation: %f, memory_percent: %4s' % (
                                validation_step, validation_step * self.batch_no_validation
                                , acc_vali, loss_vali, str(process.memory_percent()),
                            ))

                            # end while
                    settings.queue_isready_vl = False
                    acc_validation = acc_validation / (validation_step)
                    loss_validation = loss_validation / (validation_step)
                    dsc_validation = dsc_validation / (validation_step)
                    if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    _fill_thread_vl.kill_thread()
                    print('******Validation, step: %d , accuracy: %.4f, loss: %f*******' % (
                    itr1, acc_validation, loss_validation))
                    [sum_validation] = sess.run([summ],
                                                feed_dict={image: validation_CT_image_patchs,
                                                           label: validation_GTV_label,
                                                           dropout: 1,
                                                           is_training: False,
                                                           ave_vali_acc: acc_validation,
                                                           ave_loss_vali: loss_validation,
                                                           ave_dsc_vali:dsc_validation,
                                                           dense_net_dim: self.patch_window,
                                                           is_training_bn: False,
                                                           alpha: 1,
                                                           beta: 1
                                                           })
                    validation_writer.add_summary(sum_validation, point)
                    print('end of validation---------%d' % (point))

                    # end if


                '''loop for training batches'''
                while(step*self.batch_no<self.no_sample_per_each_itr):

                    [train_CT_image_patchs, train_GTV_label, train_Penalize_patch,loss_coef_weights] = _image_class.return_patches( self.batch_no)

                    # [train_CT_image_patchs, train_GTV_label] = _image_class.return_patches_overfit( step*self.batch_no,(step+1)*self.batch_no)


                    if (len(train_CT_image_patchs)<self.batch_no)|(len(train_GTV_label)<self.batch_no):
                        #|(len(train_Penalize_patch)<self.batch_no):
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue



                    [acc_train1, loss_train1, optimizing,out,dsc_train11] = sess.run([accuracy, cost, optimizer,y,f1_measure],
                                                                     feed_dict={image: train_CT_image_patchs,
                                                                                label: train_GTV_label,
                                                                                # pnalize: train_Penalize_patch,
                                                                                # loss_coef: loss_coef_weights,
                                                                                dropout: self.dropout_keep,
                                                                                is_training: True,
                                                                                ave_vali_acc: -1,
                                                                                ave_loss_vali: -1,
                                                                                ave_dsc_vali: -1,
                                                                                dense_net_dim: self.patch_window,
                                                                                is_training_bn: True,
                                                                                alpha: self.alpha_coeff,
                                                                                beta: self.beta_coeff
                                                                                })
                    dsc_train1=dsc_train11[1]

                    self.x_hist=self.x_hist+1
                    # np.hstack((self.x_hist, [np.ceil(
                    #     len(np.where(train_GTV_label[i, :, :, :, 1] == 1)[0]) / pow(self.GTV_patchs_size, 3) * 10) * 10
                    #                     for i in range(self.batch_no)]))


                    [sum_train] = sess.run([summ],
                                           feed_dict={image: train_CT_image_patchs,
                                                      label: train_GTV_label,
                                                      # pnalize: train_Penalize_patch,
                                                      # loss_coef: loss_coef_weights,
                                                      dropout: self.dropout_keep, is_training: True,
                                                      ave_vali_acc: acc_train1,
                                                      ave_loss_vali: loss_train1,
                                                      ave_dsc_vali: dsc_train1,
                                                      dense_net_dim: self.patch_window,
                                                      is_training_bn: True,
                                                      alpha: self.alpha_coeff,
                                                      beta: self.beta_coeff
                                                      })
                    train_writer.add_summary(sum_train,point)
                    step = step + 1



                    # if step==100:
                    #     test_in_train(test_CTs, test_GTVs, sess, accuracy, cost, y, image, label, dropout, is_training,
                    #                   ave_vali_acc, ave_loss_vali, dense_net_dim)
                    process = psutil.Process(os.getpid())
                    # print('point: %d, step*self.batch_no:%f , LR: %.15f, acc_train1:%f, loss_train1:%f,cost_before:%, memory_percent: %4s' % (int((self.x_hist)),
                    # step * self.batch_no,self.learning_rate, acc_train1, loss_train1,cost_b,str(process.memory_percent())))
                    print(
                        'point: %d, step*self.batch_no:%f , LR: %.15f, acc_train1:%f, loss_train1:%f,memory_percent: %4s' % (
                        int((point)),
                        step * self.batch_no, self.learning_rate, acc_train1, loss_train1,
                        str(process.memory_percent())))


                    # print('------------step:%d'%((self.no_sample_per_each_itr/self.batch_no)*itr1+step))
                    point = int(point)  # (self.no_sample_per_each_itr/self.batch_no)*itr1+step
                    #
                    # [out] = sess.run([y],
                    #                  feed_dict={image: train_CT_image_patchs,
                    #                             label: train_GTV_label,
                    #                             dropout: self.dropout_keep,
                    #                             is_training: True,
                    #                             ave_vali_acc: -1, ave_loss_vali: -1,
                    #                             dense_net_dim: self.patch_window})
                    # plt.close('all')
                    # imgplot = plt.imshow((train_GTV_label[30][:][:][:])[8, :, :, 1], cmap='gray')
                    # plt.figure()
                    # imgplot = plt.imshow((out[30][:][:][:])[8,:,:,1], cmap='gray')
                    if point%100==0:
                        '''saving model mid-epoch'''
                        chckpnt_path = os.path.join(self.chckpnt_dir,
                                                    ('densenet_unet_inter_epoch%d_point%d.ckpt' % (epoch, point)))
                        saver.save(sess, chckpnt_path, global_step=point)

                        # self.test_patches(test_CTs, test_GTVs, test_Torso, epoch,point, _rd,
                        #                   sess, accuracy, cost, y, image, label, dropout, is_training, ave_vali_acc,
                        #                   ave_loss_vali,
                        #                   dense_net_dim)
                    # if itr1!=0 and itr1%100==0 and itr1<1000:
                    #      self.learning_rate = self.learning_rate * np.exp(-.005*itr1)

                    # if point==400:
                    #      self.learning_rate = self.learning_rate * .1
                    # if point%4000==0:
                    #      self.learning_rate = self.learning_rate * self.learning_decay
                    # if point%10000==0:
                    #     self.learning_rate = self.learning_rate * .9

                    itr1 = itr1 + 1
                    point=point+1







            endTime = time.time()




            #==============end of epoch:



            '''saving model after each epoch'''
            chckpnt_path = os.path.join(self.chckpnt_dir, 'densenet_unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=epoch)


            print("End of epoch----> %d, elapsed time: %d" % (epoch, endTime - startTime))