def main(resume_dir=None):
    """Train the segmentation network, periodically logging, dumping and checkpointing.

    * Every ``eval_every`` iterations the training loss (on the current
      training batch) and the test loss (on a fresh test batch, noise layers
      disabled) are printed and written as TensorBoard summaries under ``log/``.
    * Every ``dump_every`` iterations the network output for that test batch
      is written, per channel and per sample, as NIfTI files to ``save_dir``
      together with the input image and a merged label volume.
    * Every ``save_every`` iterations ``global_step`` is set to the iteration
      number and a checkpoint is written to ``CHECKPOINT_FL``.

    Args:
        resume_dir: Optional directory containing earlier checkpoints. When
            given, the latest checkpoint found there is restored and training
            continues after the stored ``global_step``. When ``None`` (the
            default, i.e. the previous behaviour) all variables are freshly
            initialised.

    Raises:
        ValueError: if ``resume_dir`` is given but holds no checkpoint.
    """
    max_iterations = 150000
    eval_every = 50     # loss logging interval
    dump_every = 5000   # NIfTI dump interval; must be a multiple of eval_every
    save_every = 1000   # checkpoint interval

    # import data
    training_data = GetData(TRAINING_DIR)
    test_data = GetData(TEST_DIR)

    with tf.name_scope('inputs'):
        # network input volume batch
        x = tf.placeholder(tf.float32,
                           [Batch_SIZE, Img_depth, Img_rows, Img_cols, 1],
                           name='x_input')

        # per-class label volumes consumed by the loss
        y_ = tf.placeholder(
            tf.int16, [Batch_SIZE, Img_depth, Img_rows, Img_cols, n_class],
            name='y__input')

    # Only records the iteration count stored in checkpoints; never trained.
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Build the graph for the deep net. The result must NOT be bound to the
    # name `network`: that would make `network` a local of main() and this
    # very call would raise UnboundLocalError.
    net, outputs = network(x)

    dice_loss = dice_coef_loss(outputs, y_)

    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer(1e-5).minimize(dice_loss)

    # Build the step-assignment op once. Calling global_step.assign(i) inside
    # the loop would add a new op to the graph at every checkpoint.
    step_value = tf.placeholder(tf.int32, shape=[], name="global_step_value")
    set_global_step = tf.assign(global_step, step_value)

    # add ops to save and restore all the variables
    saver = tf.train.Saver()

    training_summary = tf.summary.scalar("training_loss", dice_loss)
    validation_summary = tf.summary.scalar("validation_loss", dice_loss)

    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    m_config = tf.ConfigProto()
    m_config.gpu_options.allow_growth = True

    with tf.Session(config=m_config) as sess:
        summary_writer = tf.summary.FileWriter("log/", sess.graph)
        try:
            if resume_dir is None:
                sess.run(tf.global_variables_initializer())
            else:
                checkpoint = tf.train.latest_checkpoint(resume_dir)
                if checkpoint is None:
                    raise ValueError(
                        "no checkpoint found in %r" % (resume_dir,))
                print("Restoring from checkpoint:", checkpoint)
                saver.restore(sess, checkpoint)

            global_step_value = sess.run(global_step)
            print("Last iteration:", global_step_value)

            # Same keys as net.all_drop but every value set to 1, which
            # disables the noise (dropout) layers when evaluating.
            dp_dict = tl.utils.dict_to_one(net.all_drop)

            for i in range(global_step_value + 1, max_iterations + 1):
                images, labels = training_data.next_batch(Batch_SIZE)
                feed_dict_train = {x: images, y_: labels}
                feed_dict_train.update(net.all_drop)  # enable noise layers
                sess.run(train_step, feed_dict=feed_dict_train)

                if i % eval_every == 0:
                    print("iteration now:", i)
                    train_loss, train_summ = sess.run(
                        [dice_loss, training_summary],
                        feed_dict=feed_dict_train)
                    summary_writer.add_summary(train_summ, i)
                    print('train loss %g' % train_loss)

                    images_test, labels_test = test_data.next_batch(
                        Batch_SIZE)
                    feed_dict_test = {x: images_test, y_: labels_test}
                    feed_dict_test.update(dp_dict)
                    valid_loss, valid_summ = sess.run(
                        [dice_loss, validation_summary],
                        feed_dict=feed_dict_test)
                    summary_writer.add_summary(valid_summ, i)
                    print('test loss %g' % valid_loss)
                    print('----------------------------------')

                    # Nested here because the dump reuses this test batch.
                    if i % dump_every == 0:
                        print("iteration now:", i)
                        output_image = sess.run(outputs,
                                                feed_dict=feed_dict_test)
                        print(type(output_image))
                        print(np.shape(output_image))
                        _save_test_volumes(output_image, images_test,
                                           labels_test, i)

                if i % save_every == 0:
                    print("iteration now:", i)
                    # Store the iteration number so a resumed run continues
                    # from here; the assign op only takes effect when run.
                    print("global_step_value:",
                          sess.run(set_global_step,
                                   feed_dict={step_value: i}))
                    # Only names the checkpoint file; it is separate from
                    # the global_step variable.
                    saver.save(sess, CHECKPOINT_FL, global_step=i)
                    print("================================")
                    print("model is saved")
        finally:
            summary_writer.close()


def _save_test_volumes(output_image, images_test, labels_test, step):
    """Write one test batch's predictions, inputs and labels as NIfTI files.

    For every sample ``j`` in the batch, write the following files into
    ``save_dir``, each with an identity affine:

    * one file per output channel, named ``out_<NAME><step>_<j>.nii.gz``;
    * the input image, named ``Input_Test_Image<step>_<j>.nii.gz``;
    * label channels 0-6 merged into a single volume, named
      ``Test_Label<step>_<j>.nii.gz``.

    The files let the individual per-class outputs be inspected visually.
    """
    # Output channel names in channel order (indices 0-7), used in file names.
    channel_names = ("LVB", "RVB", "LAB", "RAB", "MLV", "AA", "PA", "BACK")
    # Intensity given to each of label channels 0-6 in the merged volume, so
    # the classes are distinguishable when viewed.
    label_weights = (500, 600, 420, 550, 205, 820, 850)

    # Loop-invariant across samples: compute once for the whole batch.
    labels_union = sum(labels_test[..., c] * w
                       for c, w in enumerate(label_weights))
    input_image = images_test[..., 0]

    for j in range(Batch_SIZE):
        for c, name in enumerate(channel_names):
            CreatNii_save(
                output_image[j, ..., c], save_dir,
                "out_" + name + str(step) + "_" + str(j) + ".nii.gz",
                np.eye(4))
        CreatNii_save(
            input_image[j, ...], save_dir,
            "Input_Test_Image" + str(step) + "_" + str(j) + ".nii.gz",
            np.eye(4))
        CreatNii_save(
            labels_union[j, ...].astype(np.float32), save_dir,
            "Test_Label" + str(step) + "_" + str(j) + ".nii.gz",
            np.eye(4))