Example #1
    def __init__(self, config_file_name):
        # 1, load configuration file
        config = parse_config(config_file_name)
        self.config = config
        self.batch_size = config['testing'].get('batch_size', 5)
        self.roi_margin = config['testing'].get('roi_patch_margin', 5)
        self.__construct_networks()
        self.__creat_session_and_load_variables()
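
A minimal usage sketch for this constructor; the excerpt does not show the enclosing class, so TestRunner below is a hypothetical stand-in:

runner = TestRunner('test_cfg.txt')   # hypothetical class and config file name
print(runner.batch_size)              # falls back to 5 if the config omits 'batch_size'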
Example #2
def test(config_file):
    # 1, load configuration file
    config = parse_config(config_file)
    config_data = config['data']

    # 2, Augmentation
    dataaug = DataAug(config_data)
    dataaug.aug_data()
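
A command-line wrapper in the style of the __main__ block shown in Example #9 could drive this function; the script and config file names here are illustrative:

if __name__ == "__main__":       # assumes "import sys" at the top of the module
    if (len(sys.argv) != 2):
        print('Number of arguments should be 2. e.g.')
        print('    python data_augment.py aug_cfg.txt')
        exit()
    test(str(sys.argv[1]))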
Example #3
    def _load_config(self, config_file):
        print(f'Loading config file: {config_file}')

        # load the config file
        self.config = parse_config(config_file)
        self.config_data = self.config['data']
        self.config_net1 = self.config.get('network1', None)
        self.config_net2 = self.config.get('network2', None)
        self.config_net3 = self.config.get('network3', None)
        self.config_test = self.config['testing']
        self.batch_size = self.config_test.get('batch_size', 5)
Example #4
def train(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_data = config['data']
    config_net = config['network']
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))
    assert (config_data['with_ground_truth'])

    net_type = config_net['net_type']
    net_name = config_net['net_name']
    class_num = config_net['class_num']
    batch_size = config_data.get('batch_size', 5)

    # 2, construct graph
    full_data_shape = [batch_size] + config_data['data_shape']
    full_label_shape = [batch_size] + config_data['label_shape']
    x = tf.placeholder(tf.float32, shape=full_data_shape)
    w = tf.placeholder(tf.float32, shape=full_label_shape)
    y = tf.placeholder(tf.int64, shape=full_label_shape)

    w_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes=class_num,
                    w_regularizer=w_regularizer,
                    b_regularizer=b_regularizer,
                    name=net_name)
    net.set_params(config_net)
    predicty = net(x, is_training=True)
    proby = tf.nn.softmax(predicty)

    loss_func = LossFunction(n_class=class_num)
    loss = loss_func(predicty, y, weight_map=w)
    print('shape of predicty:', predicty.get_shape())

    # 3, set up the optimizer and export the graph
    lr = config_train.get('learning_rate', 1e-3)
    opt_step = tf.train.AdamOptimizer(lr).minimize(loss)
    tf.summary.FileWriter("./graphs/" + config_net['net_name'],
                          tf.get_default_graph()).close()
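
This variant stops after exporting the graph definition: the FileWriter writes ./graphs/<net_name> and is closed immediately, so the graph can be inspected with "tensorboard --logdir ./graphs" without running a single training step. Example #6 below continues the same setup through the full training loop.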
Example #5
def enhance(config_file):
    config = parse_config(config_file)
    config_data = config['data']
    data_names = config_data.get('data_names', None)
    data_root = config_data.get('data_root', None)
    file_postfix = config_data.get('file_postfix', None)

    modality_postfix = config_data.get('modality_postfix', None)
    assert (os.path.isfile(data_names))
    with open(data_names) as f:
        content = f.readlines()
    patient_names = [x.strip() for x in content]
    data_num = len(patient_names)
    print(config_data)
    for i in range(data_num):
        volume_list = []
        volume_name_list = []
        for mod_idx in range(len(modality_postfix)):
            volume, volume_name = load_one_volume(patient_names[i],
                                                  modality_postfix[mod_idx],
                                                  data_root, file_postfix)

            start = time.time()  # time.clock() was removed in Python 3.8

            Enhanced_volume = Enhancement_GAN(volume)

            end = time.time()
            print("Time per image: {} ".format((end - start)))

            volume_name_enhanced = os.path.join(
                data_root, patient_names[i],
                patient_names[i][4:] + '_flairE4.nii.gz')

            save_array_as_nifty_volume(Enhanced_volume, volume_name_enhanced)

            print(volume_name_enhanced)
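
enhance expects data_names to point at a text file with one patient name per line; the patient_names[i][4:] slice suggests each name carries a four-character prefix such as 'HGG/'. A sketch of writing such a list (the entries are purely illustrative):

# hypothetical patient list; real names depend on the dataset layout
with open('data_names.txt', 'w') as f:
    f.write('HGG/patient_001\n')
    f.write('LGG/patient_002\n')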
Example #6
def train(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_data  = config['data']
    config_net   = config['network']
    config_train = config['training']
     
    random.seed(config_train.get('random_seed', 1))
    assert(config_data['with_ground_truth'])

    net_type    = config_net['net_type']
    net_name    = config_net['net_name']
    class_num   = config_net['class_num']
    batch_size  = config_data.get('batch_size', 5)
   
    # 2, construct graph
    full_data_shape  = [batch_size] + config_data['data_shape']
    full_label_shape = [batch_size] + config_data['label_shape']
    x = tf.placeholder(tf.float32, shape = full_data_shape)
    w = tf.placeholder(tf.float32, shape = full_label_shape)
    y = tf.placeholder(tf.int64,   shape = full_label_shape)
   
    w_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes = class_num,
                    w_regularizer = w_regularizer,
                    b_regularizer = b_regularizer,
                    name = net_name)
    net.set_params(config_net)
    predicty = net(x, is_training = True)
    proby    = tf.nn.softmax(predicty)
    
    loss_func = LossFunction(n_class=class_num)
    loss = loss_func(predicty, y, weight_map = w)
    print('shape of predicty:', predicty.get_shape())
    
    # 3, initialize session and saver
    lr = config_train.get('learning_rate', 1e-3)
    opt_step = tf.train.AdamOptimizer(lr).minimize(loss)
    sess = tf.InteractiveSession()   
    sess.run(tf.global_variables_initializer())  
    saver = tf.train.Saver()
    
    loader = DataLoader(config_data)
    train_data = loader.get_dataset('train', shuffle = True)
    batch_per_epoch = loader.get_batch_per_epoch()
    train_iterator = Iterator.from_structure(train_data.output_types,
                                             train_data.output_shapes)
    next_train_batch = train_iterator.get_next()
    train_init_op  = train_iterator.make_initializer(train_data)
    
    # 4, start to train
    loss_file = config_train['model_save_prefix'] + "_loss.txt"
    start_it  = config_train.get('start_iteration', 0)
    if( start_it > 0):
        saver.restore(sess, config_train['model_pre_trained'])
    loss_list, temp_loss_list = [], []
    for n in range(start_it, config_train['maximal_iteration']):
        if((n-start_it)%batch_per_epoch == 0):
            sess.run(train_init_op)
        one_batch = sess.run(next_train_batch)
        feed_dict = {x:one_batch['image'], w:one_batch['weight'], y:one_batch['label']}
        opt_step.run(session = sess, feed_dict=feed_dict)

        loss_value = loss.eval(feed_dict = feed_dict)
        temp_loss_list.append(loss_value)
        if((n+1)%config_train['loss_display_iteration'] == 0):
            avg_loss = np.asarray(temp_loss_list, np.float32).mean()
            t = time.strftime('%X %x %Z')
            print(t, 'iter', n+1,'loss', avg_loss)
            loss_list.append(avg_loss)
            np.savetxt(loss_file, np.asarray(loss_list))
            temp_loss_list = []
        if((n+1)%config_train['snapshot_iteration']  == 0):
            saver.save(sess, config_train['model_save_prefix']+"_{0:}.ckpt".format(n+1))
    sess.close()
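
parse_config evidently returns a nested dict of sections; a hand-written equivalent covering the keys this train function reads might look like the following (all values illustrative):

config = {
    'data': {'data_shape': [19, 180, 160, 4], 'label_shape': [11, 180, 160, 1],
             'batch_size': 5, 'with_ground_truth': True},
    'network': {'net_type': 'MSNet', 'net_name': 'MSNet_WT32', 'class_num': 2},
    'training': {'learning_rate': 1e-3, 'decay': 1e-7, 'random_seed': 1,
                 'start_iteration': 0, 'maximal_iteration': 20000,
                 'loss_display_iteration': 100, 'snapshot_iteration': 5000,
                 'model_save_prefix': 'model/msnet_wt32'},
}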
Example #7
def train(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_data = config['data']
    config_net = config['network']
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))
    assert (config_data['with_ground_truth'])

    net_type = config_net['net_type']
    net_name = config_net['net_name']
    class_num = config_net['class_num']
    batch_size = config_data.get('batch_size', 5)

    # 2, construct graph
    full_data_shape = [batch_size] + config_data['data_shape']
    full_label_shape = [batch_size] + config_data['label_shape']
    x = tf.placeholder(tf.float32, shape=full_data_shape)
    w = tf.placeholder(tf.float32, shape=full_label_shape)
    y = tf.placeholder(tf.int64, shape=full_label_shape)

    w_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes=class_num,
                    w_regularizer=w_regularizer,
                    b_regularizer=b_regularizer,
                    name=net_name)
    net.set_params(config_net)
    predicty = net(x, is_training=True)
    proby = tf.nn.softmax(predicty)

    loss_func = LossFunction(n_class=class_num)
    loss = loss_func(predicty, y, weight_map=w)
    print('shape of predicty:', predicty.get_shape())

    # 3, initialize session and saver
    lr = config_train.get('learning_rate', 1e-3)
    opt_step = tf.train.AdamOptimizer(lr).minimize(loss)
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    dataloader = DataLoader(config_data)
    dataloader.load_data()

    # 4, start to train
    loss_file = config_train['model_save_prefix'] + "_loss.txt"
    start_it = config_train.get('start_iteration', 0)
    if (start_it > 0):
        saver.restore(sess, config_train['model_pre_trained'])
    loss_list, temp_loss_list = [], []
    for n in range(start_it, config_train['maximal_iteration']):
        train_pair = dataloader.get_subimage_batch()
        tempx = train_pair['images']
        tempw = train_pair['weights']
        tempy = train_pair['labels']
        opt_step.run(session=sess, feed_dict={x: tempx, w: tempw, y: tempy})

        if (n % config_train['test_iteration'] == 0):
            batch_loss_list = []
            for step in range(config_train['test_step']):
                train_pair = dataloader.get_subimage_batch()
                tempx = train_pair['images']
                tempw = train_pair['weights']
                tempy = train_pair['labels']
                # note: this evaluates the training loss on a fresh batch, not a Dice score
                loss_value = loss.eval(feed_dict={x: tempx, w: tempw, y: tempy})
                batch_loss_list.append(loss_value)
            batch_loss = np.asarray(batch_loss_list, np.float32).mean()
            t = time.strftime('%X %x %Z')
            print(t, 'n', n, 'loss', batch_loss)
            loss_list.append(batch_loss)
            np.savetxt(loss_file, np.asarray(loss_list))

        if ((n + 1) % config_train['snapshot_iteration'] == 0):
            saver.save(
                sess,
                config_train['model_save_prefix'] + "_{0:}.ckpt".format(n + 1))
    sess.close()
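
Examples #6 and #7 differ mainly in the input pipeline: #6 streams batches through a tf.data Iterator that is re-initialized every epoch, while #7 draws random sub-image batches directly from a custom DataLoader and adds a periodic evaluation pass.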
Example #8
def run(config_file):
    # construct graph
    config = parse_config(config_file)
    config_data = config['data']
    config_net1 = config.get('network1', None)
    config_net2 = config.get('network2', None)
    config_net3 = config.get('network3', None)
    config_test = config['testing']  
    batch_size   = config_test.get('batch_size', 5) 
    if(config_net1):
        net_type1    = config_net1['net_type']
        net_name1    = config_net1['net_name']
        data_shape1  = config_net1['data_shape']
        label_shape1 = config_net1['label_shape']
        data_channel1= config_net1['data_channel']
        class_num1   = config_net1['class_num']
        # construct graph for 1st network
        full_data_shape1 = [batch_size] + data_shape1 + [data_channel1]
        x1 = tf.placeholder(tf.float32, shape = full_data_shape1)          
        net_class1 = NetFactory.create(net_type1)
        net1 = net_class1(num_classes = class_num1,w_regularizer = None,
                    b_regularizer = None, name = net_name1)
        net1.set_params(config_net1)
        predicty1 = net1(x1, is_training = True)
        proby1 = tf.nn.softmax(predicty1)
    else:
        config_net1ax = config['network1ax']
        config_net1sg = config['network1sg']
        config_net1cr = config['network1cr']
        
        # construct graph for 1st network axial
        net_type1ax    = config_net1ax['net_type']
        net_name1ax    = config_net1ax['net_name']
        data_shape1ax  = config_net1ax['data_shape']
        label_shape1ax = config_net1ax['label_shape']
        data_channel1ax= config_net1ax['data_channel']
        class_num1ax   = config_net1ax['class_num']
        
        full_data_shape1ax = [batch_size] + data_shape1ax + [data_channel1ax]
        x1ax = tf.placeholder(tf.float32, shape = full_data_shape1ax)          
        net_class1ax = NetFactory.create(net_type1ax)
        net1ax = net_class1ax(num_classes = class_num1ax,w_regularizer = None,
                    b_regularizer = None, name = net_name1ax)
        net1ax.set_params(config_net1ax)
        predicty1ax = net1ax(x1ax, is_training = True)
        proby1ax = tf.nn.softmax(predicty1ax)

        # construct graph for 1st network sagittal
        net_type1sg    = config_net1sg['net_type']
        net_name1sg    = config_net1sg['net_name']
        data_shape1sg  = config_net1sg['data_shape']
        label_shape1sg = config_net1sg['label_shape']
        data_channel1sg= config_net1sg['data_channel']
        class_num1sg   = config_net1sg['class_num']
        # construct graph for 1st network
        full_data_shape1sg = [batch_size] + data_shape1sg + [data_channel1sg]
        x1sg = tf.placeholder(tf.float32, shape = full_data_shape1sg)          
        net_class1sg = NetFactory.create(net_type1sg)
        net1sg = net_class1sg(num_classes = class_num1sg,w_regularizer = None,
                    b_regularizer = None, name = net_name1sg)
        net1sg.set_params(config_net1sg)
        predicty1sg = net1sg(x1sg, is_training = True)
        proby1sg = tf.nn.softmax(predicty1sg)
            
        # construct graph for 1st network coronal
        net_type1cr    = config_net1cr['net_type']
        net_name1cr    = config_net1cr['net_name']
        data_shape1cr  = config_net1cr['data_shape']
        label_shape1cr = config_net1cr['label_shape']
        data_channel1cr= config_net1cr['data_channel']
        class_num1cr   = config_net1cr['class_num']
        # construct graph for 1st network
        full_data_shape1cr = [batch_size] + data_shape1cr + [data_channel1cr]
        x1cr = tf.placeholder(tf.float32, shape = full_data_shape1cr)          
        net_class1cr = NetFactory.create(net_type1cr)
        net1cr = net_class1cr(num_classes = class_num1cr,w_regularizer = None,
                    b_regularizer = None, name = net_name1cr)
        net1cr.set_params(config_net1cr)
        predicty1cr = net1cr(x1cr, is_training = True)
        proby1cr = tf.nn.softmax(predicty1cr)
    
    # networks for tumor core
    if(config_net2):
        net_type2    = config_net2['net_type']
        net_name2    = config_net2['net_name']
        data_shape2  = config_net2['data_shape']
        label_shape2 = config_net2['label_shape']
        data_channel2= config_net2['data_channel']
        class_num2   = config_net2['class_num']
        # construct graph for 2nd network
        full_data_shape2 = [batch_size] + data_shape2 + [data_channel2]
        x2 = tf.placeholder(tf.float32, shape = full_data_shape2)          
        net_class2 = NetFactory.create(net_type2)
        net2 = net_class2(num_classes = class_num2,w_regularizer = None,
                    b_regularizer = None, name = net_name2)
        net2.set_params(config_net2)
        predicty2 = net2(x2, is_training = True)
        proby2 = tf.nn.softmax(predicty2)
    else:
        config_net2ax = config['network2ax']
        config_net2sg = config['network2sg']
        config_net2cr = config['network2cr']
        
        # construct graph for 2nd network axial
        net_type2ax    = config_net2ax['net_type']
        net_name2ax    = config_net2ax['net_name']
        data_shape2ax  = config_net2ax['data_shape']
        label_shape2ax = config_net2ax['label_shape']
        data_channel2ax= config_net2ax['data_channel']
        class_num2ax   = config_net2ax['class_num']
        
        full_data_shape2ax = [batch_size] + data_shape2ax + [data_channel2ax]
        x2ax = tf.placeholder(tf.float32, shape = full_data_shape2ax)          
        net_class2ax = NetFactory.create(net_type2ax)
        net2ax = net_class2ax(num_classes = class_num2ax,w_regularizer = None,
                    b_regularizer = None, name = net_name2ax)
        net2ax.set_params(config_net2ax)
        predicty2ax = net2ax(x2ax, is_training = True)
        proby2ax = tf.nn.softmax(predicty2ax)

        # construct graph for 2nd network sagittal
        net_type2sg    = config_net2sg['net_type']
        net_name2sg    = config_net2sg['net_name']
        data_shape2sg  = config_net2sg['data_shape']
        label_shape2sg = config_net2sg['label_shape']
        data_channel2sg= config_net2sg['data_channel']
        class_num2sg   = config_net2sg['class_num']
        # construct graph for 2nd network
        full_data_shape2sg = [batch_size] + data_shape2sg + [data_channel2sg]
        x2sg = tf.placeholder(tf.float32, shape = full_data_shape2sg)          
        net_class2sg = NetFactory.create(net_type2sg)
        net2sg = net_class2sg(num_classes = class_num2sg,w_regularizer = None,
                    b_regularizer = None, name = net_name2sg)
        net2sg.set_params(config_net2sg)
        predicty2sg = net2sg(x2sg, is_training = True)
        proby2sg = tf.nn.softmax(predicty2sg)
            
        # construct graph for 2nd network coronal
        net_type2cr    = config_net2cr['net_type']
        net_name2cr    = config_net2cr['net_name']
        data_shape2cr  = config_net2cr['data_shape']
        label_shape2cr = config_net2cr['label_shape']
        data_channel2cr= config_net2cr['data_channel']
        class_num2cr   = config_net2cr['class_num']
        # construct graph for 2nd network
        full_data_shape2cr = [batch_size] + data_shape2cr + [data_channel2cr]
        x2cr = tf.placeholder(tf.float32, shape = full_data_shape2cr)          
        net_class2cr = NetFactory.create(net_type2cr)
        net2cr = net_class2cr(num_classes = class_num2cr,w_regularizer = None,
                    b_regularizer = None, name = net_name2cr)
        net2cr.set_params(config_net2cr)
        predicty2cr = net2cr(x2cr, is_training = True)
        proby2cr = tf.nn.softmax(predicty2cr)

    # for enhanced tumor
    if(config_net3):
        net_type3    = config_net3['net_type']
        net_name3    = config_net3['net_name']
        data_shape3  = config_net3['data_shape']
        label_shape3 = config_net3['label_shape']
        data_channel3= config_net3['data_channel']
        class_num3   = config_net3['class_num']
        # construct graph for 3rd network
        full_data_shape3 = [batch_size] + data_shape3 + [data_channel3]
        x3 = tf.placeholder(tf.float32, shape = full_data_shape3)          
        net_class3 = NetFactory.create(net_type3)
        net3 = net_class3(num_classes = class_num3,w_regularizer = None,
                    b_regularizer = None, name = net_name3)
        net3.set_params(config_net3)
        predicty3 = net3(x3, is_training = True)
        proby3 = tf.nn.softmax(predicty3)
    else:
        config_net3ax = config['network3ax']
        config_net3sg = config['network3sg']
        config_net3cr = config['network3cr']
        
        # construct graph for 3rd network axial
        net_type3ax    = config_net3ax['net_type']
        net_name3ax    = config_net3ax['net_name']
        data_shape3ax  = config_net3ax['data_shape']
        label_shape3ax = config_net3ax['label_shape']
        data_channel3ax= config_net3ax['data_channel']
        class_num3ax   = config_net3ax['class_num']
        
        full_data_shape3ax = [batch_size] + data_shape3ax + [data_channel3ax]
        x3ax = tf.placeholder(tf.float32, shape = full_data_shape3ax)          
        net_class3ax = NetFactory.create(net_type3ax)
        net3ax = net_class3ax(num_classes = class_num3ax,w_regularizer = None,
                    b_regularizer = None, name = net_name3ax)
        net3ax.set_params(config_net3ax)
        predicty3ax = net3ax(x3ax, is_training = True)
        proby3ax = tf.nn.softmax(predicty3ax)

        # construct graph for 3rd network sagittal
        net_type3sg    = config_net3sg['net_type']
        net_name3sg    = config_net3sg['net_name']
        data_shape3sg  = config_net3sg['data_shape']
        label_shape3sg = config_net3sg['label_shape']
        data_channel3sg= config_net3sg['data_channel']
        class_num3sg   = config_net3sg['class_num']
        # construct graph for 3rd network
        full_data_shape3sg = [batch_size] + data_shape3sg + [data_channel3sg]
        x3sg = tf.placeholder(tf.float32, shape = full_data_shape3sg)          
        net_class3sg = NetFactory.create(net_type3sg)
        net3sg = net_class3sg(num_classes = class_num3sg,w_regularizer = None,
                    b_regularizer = None, name = net_name3sg)
        net3sg.set_params(config_net3sg)
        predicty3sg = net3sg(x3sg, is_training = True)
        proby3sg = tf.nn.softmax(predicty3sg)
            
        # construct graph for 3rd network coronal
        net_type3cr    = config_net3cr['net_type']
        net_name3cr    = config_net3cr['net_name']
        data_shape3cr  = config_net3cr['data_shape']
        label_shape3cr = config_net3cr['label_shape']
        data_channel3cr= config_net3cr['data_channel']
        class_num3cr   = config_net3cr['class_num']
        # construct graph for 3rd network
        full_data_shape3cr = [batch_size] + data_shape3cr + [data_channel3cr]
        x3cr = tf.placeholder(tf.float32, shape = full_data_shape3cr)          
        net_class3cr = NetFactory.create(net_type3cr)
        net3cr = net_class3cr(num_classes = class_num3cr,w_regularizer = None,
                    b_regularizer = None, name = net_name3cr)
        net3cr.set_params(config_net3cr)
        predicty3cr = net3cr(x3cr, is_training = True)
        proby3cr = tf.nn.softmax(predicty3cr)
        
    all_vars = tf.global_variables()
    print('all vars', len(all_vars))
    sess = tf.InteractiveSession()   
    sess.run(tf.global_variables_initializer())  
    if(config_net1):
        net1_vars = [x for x in all_vars if x.name[0:len(net_name1) + 1]==net_name1 + '/']
        saver1 = tf.train.Saver(net1_vars)
        saver1.restore(sess, config_net1['model_file'])
    else:
        # (disabled) one-off conversion code that copied the axial network's weights
        # into the sagittal/coronal variables and saved them as new checkpoints

        net1ax_vars = [x for x in all_vars if x.name[0:len(net_name1ax) + 1]==net_name1ax + '/']
        saver1ax = tf.train.Saver(net1ax_vars)
        print('net1ax', len(net1ax_vars))
        saver1ax.restore(sess, config_net1ax['model_file'])
        net1sg_vars = [x for x in all_vars if x.name[0:len(net_name1sg) + 1]==net_name1sg + '/']
        print('net1sg', len(net1sg_vars))
        saver1sg = tf.train.Saver(net1sg_vars)
        saver1sg.restore(sess, config_net1sg['model_file'])     
        net1cr_vars = [x for x in all_vars if x.name[0:len(net_name1cr) + 1]==net_name1cr + '/']
        saver1cr = tf.train.Saver(net1cr_vars)
        print('net1cr', len(net1cr_vars))
        saver1cr.restore(sess, config_net1cr['model_file'])

    if(config_net2):
        net2_vars = [x for x in all_vars if x.name[0:len(net_name2) + 1]==net_name2 + '/']
        saver2 = tf.train.Saver(net2_vars)
        saver2.restore(sess, config_net2['model_file'])
    else:
        # (disabled) one-off conversion code that copied the axial network's weights
        # into the sagittal/coronal variables and saved them as new checkpoints

        net2ax_vars = [x for x in all_vars if x.name[0:len(net_name2ax)+1]==net_name2ax + '/']
        saver2ax = tf.train.Saver(net2ax_vars)
        saver2ax.restore(sess, config_net2ax['model_file'])
        net2sg_vars = [x for x in all_vars if x.name[0:len(net_name2sg)+1]==net_name2sg + '/']
        saver2sg = tf.train.Saver(net2sg_vars)
        saver2sg.restore(sess, config_net2sg['model_file'])     
        net2cr_vars = [x for x in all_vars if x.name[0:len(net_name2cr)+1]==net_name2cr + '/']
        saver2cr = tf.train.Saver(net2cr_vars)
        saver2cr.restore(sess, config_net2cr['model_file'])     

    if(config_net3):
        net3_vars = [x for x in all_vars if x.name[0:len(net_name3) + 1]==net_name3 + '/']
        saver3 = tf.train.Saver(net3_vars)
        saver3.restore(sess, config_net3['model_file'])
    else:
        # (disabled) one-off conversion code that copied the axial network's weights
        # into the sagittal/coronal variables and saved them as new checkpoints
        net3ax_vars = [x for x in all_vars if x.name[0:len(net_name3ax) + 1]==net_name3ax+ '/']
        saver3ax = tf.train.Saver(net3ax_vars)
        saver3ax.restore(sess, config_net3ax['model_file'])
        net3sg_vars = [x for x in all_vars if x.name[0:len(net_name3sg) + 1]==net_name3sg+ '/']
        saver3sg = tf.train.Saver(net3sg_vars)
        saver3sg.restore(sess, config_net3sg['model_file'])     
        net3cr_vars = [x for x in all_vars if x.name[0:len(net_name3cr) + 1]==net_name3cr+ '/']
        saver3cr = tf.train.Saver(net3cr_vars)
        saver3cr.restore(sess, config_net3cr['model_file'])     

    loader = DataLoader()
    loader.set_params(config_data)
    loader.load_data()
  
    # start to test  
    image_num = loader.get_total_image_number()
    test_slice_direction = config_test.get('test_slice_direction', 'all')
    save_folder = config_test['save_folder']
    test_time = []
    struct = ndimage.generate_binary_structure(3, 2)
    margin = config_test.get('roi_patch_margin', 5)
    for i in range(image_num):
        # test of 1st network
        t0 = time.time()
        [imgs, weight, temp_name] = loader.get_image_data_with_name(i)
        groi = get_roi(weight > 0, margin)
        temp_imgs = [x[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] \
                      for x in imgs]
        temp_weight = weight[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))]

        if(config_net1):
            data_shapes = [data_shape1, data_shape1, data_shape1]
            label_shapes = [label_shape1, label_shape1, label_shape1]
            nets = [net1, net1, net1]
            outputs = [proby1, proby1, proby1]
            inputs =  [x1, x1, x1]
            data_channel = data_channel1
            class_num = class_num1
        else:
            data_shapes = [data_shape1ax, data_shape1sg, data_shape1cr]
            label_shapes = [label_shape1ax, label_shape1sg, label_shape1cr]
            nets = [net1ax, net1sg, net1cr]
            outputs = [proby1ax, proby1sg, proby1cr]
            inputs =  [x1ax, x1sg, x1cr]
            data_channel = data_channel1ax
            class_num = class_num1ax
        prob1 = test_one_image_three_nets_adaptive_shape(temp_imgs, data_shapes, label_shapes, data_channel, class_num,
                   batch_size, sess, nets, outputs, inputs, shape_mode = 0)
        pred1 =  np.asarray(np.argmax(prob1, axis = 3), np.uint16)
        pred1 = pred1 * temp_weight

        # (disabled) optional label conversion driven by config_test's
        # label_convert_source / label_convert_target
        # test of 2nd network
        wt_threshold = 2000
        if(pred1.sum() == 0):
            print('net1 output is null', temp_name)
            roi2 = get_roi(temp_imgs[0] > 0, margin)
        else:
            pred1_lc = ndimage.morphology.binary_closing(pred1, structure = struct)
            pred1_lc = get_largest_two_component(pred1_lc, True, wt_threshold)
            roi2 = get_roi(pred1_lc, margin)
        sub_imgs = [x[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] \
                      for x in temp_imgs]
        sub_weight = temp_weight[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))]
        if(config_net2):
            data_shapes = [data_shape2, data_shape2, data_shape2]
            label_shapes = [label_shape2, label_shape2, label_shape2]
            nets = [net2, net2, net2]
            outputs = [proby2, proby2, proby2]
            inputs =  [x2, x2, x2]
            data_channel = data_channel2
            class_num = class_num2
        else:
            data_shapes = [data_shape2ax, data_shape2sg, data_shape2cr]
            label_shapes = [label_shape2ax, label_shape2sg, label_shape2cr]
            nets = [net2ax, net2sg, net2cr]
            outputs = [proby2ax, proby2sg, proby2cr]
            inputs =  [x2ax, x2sg, x2cr]
            data_channel = data_channel2ax
            class_num = class_num2ax
        prob2 = test_one_image_three_nets_adaptive_3dshape(sub_imgs, data_shapes, label_shapes, data_channel, class_num, sess, nets, outputs, inputs)
#        prob2 = test_one_image_three_nets_adaptive_shape(sub_imgs, data_shapes, label_shapes, data_channel, class_num,  batch_size, sess, nets, outputs, inputs, shape_mode = 1)
        pred2 = np.asarray(np.argmax(prob2, axis = 3), np.uint16)
        pred2 = pred2 * sub_weight
         
        # test of 3rd network
        if(pred2.sum() == 0):
            [roid, roih, roiw] = sub_imgs[0].shape
            roi3 = [0, roid, 0, roih, 0, roiw]
            subsub_imgs = sub_imgs
            subsub_weight = sub_weight
        else:
            pred2_lc = ndimage.morphology.binary_closing(pred2, structure = struct)
            pred2_lc = get_largest_two_component(pred2_lc)
            roi3 = get_roi(pred2_lc, margin)
            subsub_imgs = [x[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] \
                      for x in sub_imgs]
            subsub_weight = sub_weight[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] 
        
        if(config_net3):
            data_shapes = [data_shape3, data_shape3, data_shape3]
            label_shapes = [label_shape3, label_shape3, label_shape3]
            nets = [net3, net3, net3]
            outputs = [proby3, proby3, proby3]
            inputs =  [x3, x3, x3]
            data_channel = data_channel3
            class_num = class_num3
        else:
            data_shapes = [data_shape3ax, data_shape3sg, data_shape3cr]
            label_shapes = [label_shape3ax, label_shape3sg, label_shape3cr]
            nets = [net3ax, net3sg, net3cr]
            outputs = [proby3ax, proby3sg, proby3cr]
            inputs =  [x3ax, x3sg, x3cr]
            data_channel = data_channel3ax
            class_num = class_num3ax
        prob3 = test_one_image_three_nets_adaptive_shape(subsub_imgs, data_shapes, label_shapes, data_channel, class_num,
                   batch_size, sess, nets, outputs, inputs, shape_mode = 1)
        
        pred3 = np.asarray(np.argmax(prob3, axis = 3), np.uint16)
        pred3 = pred3 * subsub_weight
         
        # fuse results at 3 levels
        # convert subsub_label (enhancing tumor) back to the full ROI size
        label3_roi = np.zeros_like(pred2)
        label3_roi[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] = pred3
        label3 = np.zeros_like(pred1)
        label3[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = label3_roi

        # convert sub_label to full size (tumor core)
        label2 = np.zeros_like(pred1)
        label2[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = pred2

        
        # fuse the results
        label1_mask = (pred1 + label2 + label3) > 0
        label1_mask = ndimage.morphology.binary_closing(label1_mask, structure = struct)
        label1_mask = get_largest_two_component(label1_mask, False, wt_threshold)
        label1 = pred1 * label1_mask
        
        label2_3_mask = (label2 + label3) > 0
        label2_3_mask = label2_3_mask * label1_mask
        label2_3_mask = ndimage.morphology.binary_closing(label2_3_mask, structure = struct)
        label2_3_mask = remove_external_core(label1, label2_3_mask)
        if(label2_3_mask.sum() > 0):
            label2_3_mask = get_largest_two_component(label2_3_mask)
        label1 = (label1 + label2_3_mask) > 0
        label2 = label2_3_mask
        label3 = label2 * label3
        vox_3  = label3.sum() 
        if(0 < vox_3 and vox_3 < 30):
            print('ignored voxel number ', vox_3, flush = True)
            label3 = np.zeros_like(label2)
            
        out_label = label1 * 2 
        out_label[label2>0] = 1
        out_label[label3>0] = 4
        out_label = np.asarray(out_label, np.int16)

        test_time.append(time.time() - t0)
        final_label = np.zeros_like(weight, np.int16)
        final_label[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] = out_label

        save_array_as_nifty_volume(final_label, save_folder+"/{0:}.nii.gz".format(temp_name))
        print(temp_name, flush = True)
    test_time = np.asarray(test_time)
    print('test time', test_time.mean(), flush= True)
    np.savetxt(save_folder + '/test_time.txt', test_time)
    sess.close()
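
The ROI cropping idiom used throughout this function, x[np.ix_(range(a, b), range(c, d), range(e, f))], is equivalent to ordinary slicing; a small self-contained check:

import numpy as np

vol = np.arange(4 * 4 * 4).reshape(4, 4, 4)
roi = [1, 3, 0, 2, 2, 4]  # [d0, d1, h0, h1, w0, w1], the layout get_roi appears to return
crop_a = vol[np.ix_(range(roi[0], roi[1]), range(roi[2], roi[3]), range(roi[4], roi[5]))]
crop_b = vol[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
assert (crop_a == crop_b).all()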
Example #9
        seg_data = []  # assumed initialization; the excerpt begins mid-function
        result_folder = "{0:}/{1:}".format(result_root, item)
        seg_files = os.listdir(result_folder)
        for seg_file in seg_files:
            if ("nii.gz" in seg_file):
                full_seg_name = "{0:}/{1:}".format(result_folder, seg_file)
                img_obj = nibabel.load(full_seg_name)
                img_data = np.asanyarray(img_obj.dataobj)  # get_data() is deprecated in nibabel
                seg_data.append(img_data)

        seg_data = np.asarray(seg_data)
        vote_data = stats.mode(seg_data, axis=0)[0][0]

        # post process
        vote_data = brats_post_process(vote_data)
        output_img = nibabel.nifti1.Nifti1Image(vote_data, img_obj.affine,
                                                img_obj.header, img_obj.extra,
                                                img_obj.file_map)

        save_filename = "{0:}.nii.gz".format(result_folder)
        nibabel.save(output_img, save_filename)


if __name__ == "__main__":
    if (len(sys.argv) != 2):
        print('Number of arguments should be 2. e.g.')
        print('    python get_vote_result.py vote_result_cfg.txt')
        exit()
    config_file = str(sys.argv[1])
    config = parse_config(config_file)
    print(config)
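
stats.mode along axis 0 implements a voxel-wise majority vote across the collected segmentations; a toy example with three "models":

import numpy as np
from scipy import stats

segs = np.asarray([[0, 1, 1, 2],    # model A
                   [0, 1, 2, 2],    # model B
                   [1, 1, 1, 2]])   # model C
vote = stats.mode(segs, axis=0)[0][0]  # older SciPy keeps the reduced axis, hence [0][0]
print(vote)  # [0 1 1 2]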
Example #10
def seg(config_file):
    # 1, load configuration parameters
    print('1.Load parameters')
    config = parse_config(config_file)
    config_data = config['data']  # data settings, e.g. data_shape, batch_size
    config_net = config['network']  # network settings, e.g. net_name, base_feature_number, class_num
    config_train = config['training']
    random.seed(config_train.get('random_seed', 1))
    output_feature = config_data.get('output_feature', False)
    net_type = config_net['net_type']
    class_num = config_net['class_num']
    save = False
    show = False
    cal_dice = True
    cal_assd = False

    # 2, load data
    print('2.Load data')
    Datamode = ['valid']

    # 3. create model
    print('3.Create model')
    # dice_eval = TestDiceLoss(class_num)
    net_class = NetFactory.create(net_type)
    net = net_class(inc=config_net.get('input_channel', 1),
                    n_classes=class_num,
                    base_chns=config_net.get('base_feature_number', 16),
                    droprate=config_net.get('drop_rate', 0.2),
                    norm='in',
                    depth=False,
                    dilation=1)

    net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda()
    if config_train['load_weight']:
        weight = torch.load(config_train['model_path'],
                            map_location=lambda storage, loc: storage)
        net.load_state_dict(weight)
    print(torch.cuda.is_available())

    # 4, start to segment
    print('4.Start to segment')
    net.eval()
    for mode in Datamode:
        Data = LYC_dataset(config_data, mode)
        patient_number = len(
            os.listdir(os.path.join(config_data['data_root'], mode)))
        with torch.no_grad():
            t_array = np.zeros(patient_number)
            dice_array = np.zeros([patient_number, class_num])
            assd_array = np.zeros([patient_number, class_num])
            for patient_order in range(patient_number):
                t1 = time.time()
                valid_pair, patient_path = Data.get_list_img(patient_order)
                clip_number = len(valid_pair['images'])  # number of cropped sub-volumes
                clip_height = config_data['test_data_shape'][0]  # height of each cropped sub-volume
                total_labels = valid_pair['labels'].cuda()
                predic_size = torch.Size([1, class_num
                                          ]) + total_labels.size()[1::]
                totalpredic = torch.zeros(predic_size).cuda()  # full-volume prediction
                if output_feature:
                    outfeature_size = torch.Size([
                        1, 2 * config_net.get('base_feature_number')
                    ]) + total_labels.size()[1::]
                    totalfeature = torch.zeros(outfeature_size).cuda()
                for i in range(clip_number):
                    tempx = valid_pair['images'][i].cuda()
                    if output_feature:
                        pred, outfeature = net(tempx)
                    else:
                        pred = net(tempx)
                    if i < clip_number - 1:
                        totalpredic[:, :, i * clip_height:(i + 1) *
                                    clip_height] = pred
                    else:
                        totalpredic[:, :, -clip_height::] = pred
                    if output_feature:
                        if i < clip_number - 1:
                            totalfeature[:, :, i * clip_height:(i + 1) *
                                         clip_height] = outfeature
                        else:
                            totalfeature[:, :, -clip_height::] = outfeature

                # torchdice = dice_eval(totalpredic, total_labels)
                # print('torch dice:', torchdice)
                totalpredic = torch.max(totalpredic, 1)[1].squeeze()
                totalpredic = np.uint8(
                    totalpredic.cpu().data.numpy().squeeze())
                totallabel = np.uint8(
                    total_labels.cpu().data.numpy().squeeze())
                if output_feature:
                    # bug fix: convert the feature tensor, not the (already numpy) prediction
                    totalfeature = totalfeature.cpu().data.numpy().squeeze()
                t2 = time.time()
                t = t2 - t1
                t_array[patient_order] = t

                one_hot_label = one_hot(totallabel, class_num)
                one_hot_predic = one_hot(totalpredic, class_num)

                if cal_dice:
                    Dice = np.zeros(class_num)
                    for i in range(class_num):
                        Dice[i] = dc(one_hot_predic[i], one_hot_label[i])
                    dice_array[patient_order] = Dice
                    print('patient order', patient_order, ' dice:', Dice)

                if cal_assd:
                    Assd = np.zeros(class_num)
                    for i in range(class_num):
                        Assd[i] = assd(one_hot_predic[i], one_hot_label[i], 1)
                    assd_array[patient_order] = Assd

                if show:
                    for i in np.arange(0, totalpredic.shape[0], 2):
                        f, plots = plt.subplots(1, 2)
                        plots[0].imshow(totalpredic[i])
                        plots[1].imshow(totallabel[i])
                        #plots[2].imshow(oriseg[i])
                        # plots[1, 0].imshow(totalfeature[0, i])
                        # plots[1, 1].imshow(totalfeature[5, i])
                        plt.show()
                if save:
                    if output_feature:
                        np.save(patient_path + '/Feature.npy', totalfeature)
                    #np.save(patient_path + '/Seg_2.npy', totalpredic)
                    save_array_as_nifty_volume(totalpredic,
                                               patient_path + '/Seg.nii.gz')
                    # np.savetxt(patient_path+'/Dice.npy', Dice.squeeze())
                    # np.savetxt(patient_path+'/Assd.npy', Assd.squeeze())

        if cal_dice:
            dice_array[:, 0] = np.mean(dice_array[:, 1::], 1)
            dice_mean = np.mean(dice_array, 0)
            dice_std = np.std(dice_array, 0)
            print('{0:} mode: mean dice:{1:}, std of dice:{2:}'.format(
                mode, dice_mean, dice_std))

        if cal_assd:
            assd_array[:, 0] = np.mean(assd_array[:, 1::], 1)
            assd_mean = np.mean(assd_array, 0)
            assd_std = np.std(assd_array, 0)
            print('{0:} mode: mean assd:{1:}, std of assd:{2:}'.format(
                mode, assd_mean, assd_std))

        t_mean = [t_array.mean()]
        t_std = [t_array.std()]
        print('{0:} mode: mean time:{1:}, std of time:{2:}'.format(
            mode, t_mean, t_std))
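
dc and assd here are presumably the binary metrics from medpy.metric.binary; a toy Dice computation for reference (assuming that import):

import numpy as np
from medpy.metric.binary import dc

pred = np.zeros((8, 8), np.uint8)
pred[2:6, 2:6] = 1                 # 16 foreground pixels
ref = np.zeros((8, 8), np.uint8)
ref[3:7, 3:7] = 1                  # 16 foreground pixels, shifted by one
print(dc(pred, ref))               # 2*9 / (16+16) = 0.5625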
Example #11
def train(config_file):
    # 1, load configuration parameters
    print('1.Load parameters')
    config = parse_config(config_file)
    config_data = config['data']  # data settings, e.g. data_shape, batch_size
    config_net = config['network']  # network settings, e.g. net_name, base_feature_number, class_num
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))

    valid_patient_number = len(
        os.listdir(config_data['data_root'] + '/' + 'valid'))
    net_type = config_net['net_type']
    class_num = config_net['class_num']
    batch_size = config_data.get('batch_size', 4)
    lr = config_train.get('learning_rate', 1e-3)
    best_dice = config_train.get('best_dice', 0.5)

    # 2, load data
    print('2.Load data')
    trainData = LYC_dataset(config_data, 'train')
    validData = LYC_dataset(config_data, 'valid')

    # 3. create model
    print('3.Create model')
    net_class = NetFactory.create(net_type)
    net = net_class(inc=config_net.get('input_channel', 1),
                    n_classes=class_num,
                    base_chns=config_net.get('base_feature_number', 16),
                    droprate=config_net.get('drop_rate', 0.2),
                    norm='in',
                    depth=config_net.get('depth', False),
                    dilation=config_net.get('dilation', 1),
                    separate_direction='axial')
    net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda()
    if config_train['load_weight']:
        weight = torch.load(config_train['model_path'],
                            map_location=lambda storage, loc: storage)
        net.load_state_dict(weight)

    show_param(net)

    dice_eval = TestDiceLoss(n_class=class_num)
    loss_func = AttentionExpDiceLoss(n_class=class_num, alpha=0.5)
    show_loss = loss_visualize(class_num)

    Adamoptimizer = optim.Adam(net.parameters(),
                               lr=lr,
                               weight_decay=config_train.get('decay', 1e-7))
    Adamscheduler = torch.optim.lr_scheduler.StepLR(Adamoptimizer,
                                                    step_size=10,
                                                    gamma=0.9)

    # 4, start to train
    print('4.Start to train')
    dice_file = config_train['model_save_prefix'] + "_dice.txt"
    start_it = config_train.get('start_iteration', 0)
    dice_save = np.zeros([config_train['maximal_epoch'], 2 + class_num])
    for n in range(start_it, config_train['maximal_epoch']):
        train_loss_list, train_dice_list = np.zeros(
            config_train['train_step'] //
            config_train['print_step']), np.zeros([
                config_train['train_step'] // config_train['print_step'],
                class_num
            ])
        valid_loss_list, valid_dice_list = np.zeros(
            valid_patient_number), np.zeros([valid_patient_number, class_num])

        optimizer = Adamoptimizer

        net.train()
        print('###train###\n')
        for step in range(config_train['train_step']):
            train_pair = trainData.get_subimage_batch()
            tempx = torch.FloatTensor(train_pair['images']).cuda()
            tempy = torch.FloatTensor(train_pair['labels']).cuda()
            # soft_tempy = get_soft_label(tempy.unsqueeze(1), class_num)
            predic = net(tempx)
            train_loss = loss_func(predic, tempy)
            optimizer.zero_grad()
            train_loss.backward()
            # torch.nn.utils.clip_grad_norm(net.parameters(), 10)
            optimizer.step()
            if step % config_train['print_step'] == 0:
                train_loss = train_loss.cpu().data.numpy()
                train_loss_list[step //
                                config_train['print_step']] = train_loss
                train_dice = dice_eval(predic, tempy)
                train_dice = train_dice.cpu().data.numpy()
                train_dice_list[step //
                                config_train['print_step']] = train_dice
                print('train loss:', train_loss, ' train dice:', train_dice)
        Adamscheduler.step()

        print('###test###\n')
        with torch.no_grad():
            net.eval()
            for patient_order in range(valid_patient_number):
                valid_pair, patient_path = validData.get_list_img(
                    patient_order)
                clip_number = len(valid_pair['images'])
                clip_height = config_data['test_data_shape'][0]
                total_labels = valid_pair['labels'].cuda()
                predic_size = torch.Size([1, class_num
                                          ]) + total_labels.size()[1::]
                totalpredic = torch.zeros(predic_size).cuda()

                for i in range(clip_number):
                    tempx = valid_pair['images'][i].cuda()
                    pred = net(tempx)
                    # pred[:, 0][tempx[:, 0] <= 0.0001] = 1
                    if i < clip_number - 1:
                        totalpredic[:, :, i * clip_height:(i + 1) *
                                    clip_height] = pred
                    else:
                        totalpredic[:, :, -clip_height::] = pred

                valid_dice = dice_eval(totalpredic, total_labels,
                                       show=True).cpu().data.numpy()
                valid_dice_list[patient_order] = valid_dice
                print(' valid dice:', valid_dice)

        batch_dice = [
            valid_dice_list.mean(axis=0),
            train_dice_list.mean(axis=0)
        ]
        t = time.strftime('%X %x %Z')
        print(t, 'n', n, '\ndice:\n', batch_dice)
        show_loss.plot_loss(n, batch_dice)
        train_dice_mean = np.asarray([batch_dice[1][1::].mean(axis=0)])
        valid_dice_classes = batch_dice[0][1::]
        valid_dice_mean = np.asarray([valid_dice_classes.mean(axis=0)])
        batch_dice = np.append(np.append(train_dice_mean, valid_dice_mean),
                               valid_dice_classes)
        dice_save[n] = np.append(n, batch_dice)

        if batch_dice[1] > best_dice:
            best_dice = batch_dice[1]
            torch.save(
                net.state_dict(), config_train['model_save_prefix'] +
                "_{0:}.pkl".format(batch_dice[1]))
Example #12
    def test(self):
        # 1, load configuration file
        config = parse_config(self.config)
        config_data = config['data']
        # assign the accession being listened for to this model;
        # the input dir is /gpfs/data/luilab/BRATS/data/incoming/acc#/brain
        config_data['data_names'] = self.acc + '/nifti'
        config_net1 = config.get('network1', None)
        config_net2 = config.get('network2', None)
        config_net3 = config.get('network3', None)
        config_test = config['testing']
        batch_size = config_test.get('batch_size', 5)

        # 2.1, network for whole tumor
        if (config_net1):
            net_type1 = config_net1['net_type']
            net_name1 = config_net1['net_name']
            data_shape1 = config_net1['data_shape']
            label_shape1 = config_net1['label_shape']
            class_num1 = config_net1['class_num']

            # construct graph for 1st network
            full_data_shape1 = [batch_size] + data_shape1
            x1 = tf.placeholder(tf.float32, shape=full_data_shape1)
            net_class1 = NetFactory.create(net_type1)
            net1 = net_class1(num_classes=class_num1,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1)
            net1.set_params(config_net1)
            predicty1 = net1(x1, is_training=True)
            proby1 = tf.nn.softmax(predicty1)
        else:
            config_net1ax = config['network1ax']
            config_net1sg = config['network1sg']
            config_net1cr = config['network1cr']

            # construct graph for 1st network axial
            net_type1ax = config_net1ax['net_type']
            net_name1ax = config_net1ax['net_name']
            data_shape1ax = config_net1ax['data_shape']
            label_shape1ax = config_net1ax['label_shape']
            class_num1ax = config_net1ax['class_num']

            full_data_shape1ax = [batch_size] + data_shape1ax
            x1ax = tf.placeholder(tf.float32, shape=full_data_shape1ax)
            net_class1ax = NetFactory.create(net_type1ax)
            net1ax = net_class1ax(num_classes=class_num1ax,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name1ax)
            net1ax.set_params(config_net1ax)
            predicty1ax = net1ax(x1ax, is_training=True)
            proby1ax = tf.nn.softmax(predicty1ax)

            # construct graph for 1st network sagittal
            net_type1sg = config_net1sg['net_type']
            net_name1sg = config_net1sg['net_name']
            data_shape1sg = config_net1sg['data_shape']
            label_shape1sg = config_net1sg['label_shape']
            class_num1sg = config_net1sg['class_num']

            full_data_shape1sg = [batch_size] + data_shape1sg
            x1sg = tf.placeholder(tf.float32, shape=full_data_shape1sg)
            net_class1sg = NetFactory.create(net_type1sg)
            net1sg = net_class1sg(num_classes=class_num1sg,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name1sg)
            net1sg.set_params(config_net1sg)
            predicty1sg = net1sg(x1sg, is_training=True)
            proby1sg = tf.nn.softmax(predicty1sg)

            # construct graph for 1st network coronal
            net_type1cr = config_net1cr['net_type']
            net_name1cr = config_net1cr['net_name']
            data_shape1cr = config_net1cr['data_shape']
            label_shape1cr = config_net1cr['label_shape']
            class_num1cr = config_net1cr['class_num']

            full_data_shape1cr = [batch_size] + data_shape1cr
            x1cr = tf.placeholder(tf.float32, shape=full_data_shape1cr)
            net_class1cr = NetFactory.create(net_type1cr)
            net1cr = net_class1cr(num_classes=class_num1cr,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name1cr)
            net1cr.set_params(config_net1cr)
            predicty1cr = net1cr(x1cr, is_training=True)
            proby1cr = tf.nn.softmax(predicty1cr)

        if (config_test.get('whole_tumor_only', False) is False):  # modification 1
            # 2.2, networks for tumor core
            if (config_net2):
                net_type2 = config_net2['net_type']
                net_name2 = config_net2['net_name']
                data_shape2 = config_net2['data_shape']
                label_shape2 = config_net2['label_shape']
                class_num2 = config_net2['class_num']

                # construct graph for 2nd network
                full_data_shape2 = [batch_size] + data_shape2
                x2 = tf.placeholder(tf.float32, shape=full_data_shape2)
                net_class2 = NetFactory.create(net_type2)
                net2 = net_class2(num_classes=class_num2,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name2)
                net2.set_params(config_net2)
                predicty2 = net2(x2, is_training=True)
                proby2 = tf.nn.softmax(predicty2)
            else:
                config_net2ax = config['network2ax']
                config_net2sg = config['network2sg']
                config_net2cr = config['network2cr']

                # construct graph for 2nd network axial
                net_type2ax = config_net2ax['net_type']
                net_name2ax = config_net2ax['net_name']
                data_shape2ax = config_net2ax['data_shape']
                label_shape2ax = config_net2ax['label_shape']
                class_num2ax = config_net2ax['class_num']

                full_data_shape2ax = [batch_size] + data_shape2ax
                x2ax = tf.placeholder(tf.float32, shape=full_data_shape2ax)
                net_class2ax = NetFactory.create(net_type2ax)
                net2ax = net_class2ax(num_classes=class_num2ax,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name2ax)
                net2ax.set_params(config_net2ax)
                predicty2ax = net2ax(x2ax, is_training=True)
                proby2ax = tf.nn.softmax(predicty2ax)

                # construct graph for 2nd network sagittal
                net_type2sg = config_net2sg['net_type']
                net_name2sg = config_net2sg['net_name']
                data_shape2sg = config_net2sg['data_shape']
                label_shape2sg = config_net2sg['label_shape']
                class_num2sg = config_net2sg['class_num']

                full_data_shape2sg = [batch_size] + data_shape2sg
                x2sg = tf.placeholder(tf.float32, shape=full_data_shape2sg)
                net_class2sg = NetFactory.create(net_type2sg)
                net2sg = net_class2sg(num_classes=class_num2sg,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name2sg)
                net2sg.set_params(config_net2sg)
                predicty2sg = net2sg(x2sg, is_training=True)
                proby2sg = tf.nn.softmax(predicty2sg)

                # construct graph for 2nd network coronal
                net_type2cr = config_net2cr['net_type']
                net_name2cr = config_net2cr['net_name']
                data_shape2cr = config_net2cr['data_shape']
                label_shape2cr = config_net2cr['label_shape']
                class_num2cr = config_net2cr['class_num']

                full_data_shape2cr = [batch_size] + data_shape2cr
                x2cr = tf.placeholder(tf.float32, shape=full_data_shape2cr)
                net_class2cr = NetFactory.create(net_type2cr)
                net2cr = net_class2cr(num_classes=class_num2cr,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name2cr)
                net2cr.set_params(config_net2cr)
                predicty2cr = net2cr(x2cr, is_training=True)
                proby2cr = tf.nn.softmax(predicty2cr)

            # 2.3, networks for enhancing tumor
            if (config_net3):
                net_type3 = config_net3['net_type']
                net_name3 = config_net3['net_name']
                data_shape3 = config_net3['data_shape']
                label_shape3 = config_net3['label_shape']
                class_num3 = config_net3['class_num']

                # construct graph for 3rd network
                full_data_shape3 = [batch_size] + data_shape3
                x3 = tf.placeholder(tf.float32, shape=full_data_shape3)
                net_class3 = NetFactory.create(net_type3)
                net3 = net_class3(num_classes=class_num3,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name3)
                net3.set_params(config_net3)
                predicty3 = net3(x3, is_training=True)
                proby3 = tf.nn.softmax(predicty3)
            else:
                config_net3ax = config['network3ax']
                config_net3sg = config['network3sg']
                config_net3cr = config['network3cr']

                # construct graph for 3rd network axial
                net_type3ax = config_net3ax['net_type']
                net_name3ax = config_net3ax['net_name']
                data_shape3ax = config_net3ax['data_shape']
                label_shape3ax = config_net3ax['label_shape']
                class_num3ax = config_net3ax['class_num']

                full_data_shape3ax = [batch_size] + data_shape3ax
                x3ax = tf.placeholder(tf.float32, shape=full_data_shape3ax)
                net_class3ax = NetFactory.create(net_type3ax)
                net3ax = net_class3ax(num_classes=class_num3ax,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name3ax)
                net3ax.set_params(config_net3ax)
                predicty3ax = net3ax(x3ax, is_training=True)
                proby3ax = tf.nn.softmax(predicty3ax)

                # construct graph for 3rd network sagittal
                net_type3sg = config_net3sg['net_type']
                net_name3sg = config_net3sg['net_name']
                data_shape3sg = config_net3sg['data_shape']
                label_shape3sg = config_net3sg['label_shape']
                class_num3sg = config_net3sg['class_num']
                full_data_shape3sg = [batch_size] + data_shape3sg
                x3sg = tf.placeholder(tf.float32, shape=full_data_shape3sg)
                net_class3sg = NetFactory.create(net_type3sg)
                net3sg = net_class3sg(num_classes=class_num3sg,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name3sg)
                net3sg.set_params(config_net3sg)
                predicty3sg = net3sg(x3sg, is_training=True)
                proby3sg = tf.nn.softmax(predicty3sg)

                # construct graph for 3rd network coronal
                net_type3cr = config_net3cr['net_type']
                net_name3cr = config_net3cr['net_name']
                data_shape3cr = config_net3cr['data_shape']
                label_shape3cr = config_net3cr['label_shape']
                class_num3cr = config_net3cr['class_num']
                full_data_shape3cr = [batch_size] + data_shape3cr
                x3cr = tf.placeholder(tf.float32, shape=full_data_shape3cr)
                net_class3cr = NetFactory.create(net_type3cr)
                net3cr = net_class3cr(num_classes=class_num3cr,
                                      w_regularizer=None,
                                      b_regularizer=None,
                                      name=net_name3cr)
                net3cr.set_params(config_net3cr)
                predicty3cr = net3cr(x3cr, is_training=True)
                proby3cr = tf.nn.softmax(predicty3cr)

        # 3, create session and load trained models
        print('create session and load trained models \n')
        model_t0 = time.time()

        # with tf.device("/device:GPU:0"): #0806
        all_vars = tf.global_variables()
        sess = tf.InteractiveSession()
        sess.run(tf.global_variables_initializer())
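        # each sub-network was built under its own variable scope, so its
        # variables can be selected by the 'net_name/' prefix and restored
        # from the matching checkpoint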
        if (config_net1):
            net1_vars = [
                x for x in all_vars
                if x.name[0:len(net_name1) + 1] == net_name1 + '/'
            ]
            saver1 = tf.train.Saver(net1_vars)
            saver1.restore(sess, config_net1['model_file'])
        else:
            net1ax_vars = [
                x for x in all_vars
                if x.name[0:len(net_name1ax) + 1] == net_name1ax + '/'
            ]
            saver1ax = tf.train.Saver(net1ax_vars)
            saver1ax.restore(sess, config_net1ax['model_file'])
            net1sg_vars = [
                x for x in all_vars
                if x.name[0:len(net_name1sg) + 1] == net_name1sg + '/'
            ]
            saver1sg = tf.train.Saver(net1sg_vars)
            saver1sg.restore(sess, config_net1sg['model_file'])
            net1cr_vars = [
                x for x in all_vars
                if x.name[0:len(net_name1cr) + 1] == net_name1cr + '/'
            ]
            saver1cr = tf.train.Saver(net1cr_vars)
            saver1cr.restore(sess, config_net1cr['model_file'])

        if (config_test.get('whole_tumor_only', False) is False):  # modification 2
            if (config_net2):
                net2_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name2) + 1] == net_name2 + '/'
                ]
                saver2 = tf.train.Saver(net2_vars)
                saver2.restore(sess, config_net2['model_file'])
            else:
                net2ax_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name2ax) + 1] == net_name2ax + '/'
                ]
                saver2ax = tf.train.Saver(net2ax_vars)
                saver2ax.restore(sess, config_net2ax['model_file'])
                net2sg_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name2sg) + 1] == net_name2sg + '/'
                ]
                saver2sg = tf.train.Saver(net2sg_vars)
                saver2sg.restore(sess, config_net2sg['model_file'])
                net2cr_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name2cr) + 1] == net_name2cr + '/'
                ]
                saver2cr = tf.train.Saver(net2cr_vars)
                saver2cr.restore(sess, config_net2cr['model_file'])

            if (config_net3):
                net3_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name3) + 1] == net_name3 + '/'
                ]
                saver3 = tf.train.Saver(net3_vars)
                saver3.restore(sess, config_net3['model_file'])
            else:
                net3ax_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name3ax) + 1] == net_name3ax + '/'
                ]
                saver3ax = tf.train.Saver(net3ax_vars)
                saver3ax.restore(sess, config_net3ax['model_file'])
                net3sg_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name3sg) + 1] == net_name3sg + '/'
                ]
                saver3sg = tf.train.Saver(net3sg_vars)
                saver3sg.restore(sess, config_net3sg['model_file'])
                net3cr_vars = [
                    x for x in all_vars
                    if x.name[0:len(net_name3cr) + 1] == net_name3cr + '/'
                ]
                saver3cr = tf.train.Saver(net3cr_vars)
                saver3cr.restore(sess, config_net3cr['model_file'])

        print('Model load time is {}'.format(time.time() - model_t0))

        # 4, load test images
        print('load test images \n')
        load_t0 = time.time()
        dataloader = DataLoader(config_data)
        dataloader.load_data()
        image_num = dataloader.get_total_image_number()
        print('data load time is {}'.format(time.time() - load_t0))

        # 5, start to test
        print('start to test \n')
        test_slice_direction = config_test.get('test_slice_direction', 'all')
        save_folder = config_data['save_folder']
        test_time = []
        struct = ndimage.generate_binary_structure(3, 2)
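        # 18-connected 3D structuring element for the morphological closing steps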
        margin = config_test.get('roi_patch_margin', 5)

        for i in range(image_num):
            [
                temp_imgs, temp_weight, temp_name, img_names, temp_bbox,
                temp_size
            ] = dataloader.get_image_data_with_name(i)
            t0 = time.time()
            # 5.1, test of 1st network
            if (config_net1):
                data_shapes = [
                    data_shape1[:-1], data_shape1[:-1], data_shape1[:-1]
                ]  # drop the channel dimension, e.g. [19, 180, 160, 4] -> [19, 180, 160]
                label_shapes = [
                    label_shape1[:-1], label_shape1[:-1], label_shape1[:-1]
                ]
                nets = [net1, net1, net1]
                outputs = [proby1, proby1, proby1]
                inputs = [x1, x1, x1]
                class_num = class_num1
            else:
                data_shapes = [
                    data_shape1ax[:-1], data_shape1sg[:-1], data_shape1cr[:-1]
                ]
                label_shapes = [
                    label_shape1ax[:-1], label_shape1sg[:-1],
                    label_shape1cr[:-1]
                ]
                nets = [net1ax, net1sg, net1cr]
                outputs = [proby1ax, proby1sg, proby1cr]
                inputs = [x1ax, x1sg, x1cr]
                class_num = class_num1ax
            prob1 = test_one_image_three_nets_adaptive_shape(
                temp_imgs,
                data_shapes,
                label_shapes,
                data_shape1ax[-1],
                class_num,
                batch_size,
                sess,
                nets,
                outputs,
                inputs,
                shape_mode=2)  # average the probabilities over the axial, sagittal and coronal views
            pred1 = np.asarray(np.argmax(prob1, axis=3), np.uint16)
            pred1 = pred1 * temp_weight  # mask out non-brain voxels with the weight map

            wt_threshold = 1000
            if (config_test.get('whole_tumor_only', False) is True):  # modification 3
                pred1_lc = ndimage.morphology.binary_closing(pred1,
                                                             structure=struct)
                pred1_lc = get_largest_two_component(pred1_lc, False,
                                                     wt_threshold)
                out_label = pred1_lc
            else:
                # 5.2, test of 2nd network
                if (pred1.sum() == 0):
                    print('net1 output is null', temp_name)
                    bbox1 = get_ND_bounding_box(temp_imgs[0] > 0, margin)
                else:
                    pred1_lc = ndimage.morphology.binary_closing(
                        pred1, structure=struct)
                    pred1_lc = get_largest_two_component(
                        pred1_lc, False, wt_threshold)
                    bbox1 = get_ND_bounding_box(pred1_lc, margin)
                sub_imgs = [
                    crop_ND_volume_with_bounding_box(one_img, bbox1[0],
                                                     bbox1[1])
                    for one_img in temp_imgs
                ]
                sub_weight = crop_ND_volume_with_bounding_box(
                    temp_weight, bbox1[0], bbox1[1])

                if (config_net2):
                    print("Start to testing tumor core")
                    data_shapes = [
                        data_shape2[:-1], data_shape2[:-1], data_shape2[:-1]
                    ]
                    label_shapes = [
                        label_shape2[:-1], label_shape2[:-1], label_shape2[:-1]
                    ]
                    nets = [net2, net2, net2]
                    outputs = [proby2, proby2, proby2]
                    inputs = [x2, x2, x2]
                    class_num = class_num2
                else:
                    data_shapes = [
                        data_shape2ax[:-1], data_shape2sg[:-1],
                        data_shape2cr[:-1]
                    ]
                    label_shapes = [
                        label_shape2ax[:-1], label_shape2sg[:-1],
                        label_shape2cr[:-1]
                    ]
                    nets = [net2ax, net2sg, net2cr]
                    outputs = [proby2ax, proby2sg, proby2cr]
                    inputs = [x2ax, x2sg, x2cr]
                    class_num = class_num2ax
                prob2 = test_one_image_three_nets_adaptive_shape(
                    sub_imgs,
                    data_shapes,
                    label_shapes,
                    data_shape2ax[-1],
                    class_num,
                    batch_size,
                    sess,
                    nets,
                    outputs,
                    inputs,
                    shape_mode=1)
                pred2 = np.asarray(np.argmax(prob2, axis=3), np.uint16)
                pred2 = pred2 * sub_weight

                # 5.3, test of 3rd network
                if (pred2.sum() == 0):
                    print("no tumor core found")
                    [roid, roih, roiw] = sub_imgs[0].shape
                    bbox2 = [[0, 0, 0], [roid - 1, roih - 1, roiw - 1]]
                    subsub_imgs = sub_imgs
                    subsub_weight = sub_weight
                else:
                    print("tumor core exist")
                    pred2_lc = ndimage.morphology.binary_closing(
                        pred2, structure=struct)
                    pred2_lc = get_largest_two_component(pred2_lc)
                    bbox2 = get_ND_bounding_box(pred2_lc, margin)
                    subsub_imgs = [
                        crop_ND_volume_with_bounding_box(
                            one_img, bbox2[0], bbox2[1])
                        for one_img in sub_imgs
                    ]
                    subsub_weight = crop_ND_volume_with_bounding_box(
                        sub_weight, bbox2[0], bbox2[1])

                if (config_net3):
                    print("Start to testing enhancing tumor")
                    data_shapes = [
                        data_shape3[:-1], data_shape3[:-1], data_shape3[:-1]
                    ]
                    label_shapes = [
                        label_shape3[:-1], label_shape3[:-1], label_shape3[:-1]
                    ]
                    nets = [net3, net3, net3]
                    outputs = [proby3, proby3, proby3]
                    inputs = [x3, x3, x3]
                    class_num = class_num3
                else:
                    data_shapes = [
                        data_shape3ax[:-1], data_shape3sg[:-1],
                        data_shape3cr[:-1]
                    ]
                    label_shapes = [
                        label_shape3ax[:-1], label_shape3sg[:-1],
                        label_shape3cr[:-1]
                    ]
                    nets = [net3ax, net3sg, net3cr]
                    outputs = [proby3ax, proby3sg, proby3cr]
                    inputs = [x3ax, x3sg, x3cr]
                    class_num = class_num3ax

                prob3 = test_one_image_three_nets_adaptive_shape(
                    subsub_imgs,
                    data_shapes,
                    label_shapes,
                    data_shape3ax[-1],
                    class_num,
                    batch_size,
                    sess,
                    nets,
                    outputs,
                    inputs,
                    shape_mode=1)

                pred3 = np.asarray(np.argmax(prob3, axis=3), np.uint16)
                pred3 = pred3 * subsub_weight

                # 5.4, fuse results at 3 levels
                # map the 3rd-network prediction back to the full volume size
                label3_roi = np.zeros_like(pred2)
                label3_roi = set_ND_volume_roi_with_bounding_box_range(
                    label3_roi, bbox2[0], bbox2[1], pred3)
                label3 = np.zeros_like(pred1)
                label3 = set_ND_volume_roi_with_bounding_box_range(
                    label3, bbox1[0], bbox1[1], label3_roi)

                label2 = np.zeros_like(pred1)
                label2 = set_ND_volume_roi_with_bounding_box_range(
                    label2, bbox1[0], bbox1[1], pred2)

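                # keep only the largest connected component(s) of the fused whole-tumor mask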
                label1_mask = (pred1 + label2 + label3) > 0
                label1_mask = ndimage.morphology.binary_closing(
                    label1_mask, structure=struct)
                label1_mask = get_largest_two_component(
                    label1_mask, False, wt_threshold)
                label1 = pred1 * label1_mask

                label2_3_mask = (label2 + label3) > 0
                label2_3_mask = label2_3_mask * label1_mask
                label2_3_mask = ndimage.morphology.binary_closing(
                    label2_3_mask, structure=struct)
                label2_3_mask = remove_external_core(label1, label2_3_mask)
                if (label2_3_mask.sum() > 0):
                    label2_3_mask = get_largest_two_component(label2_3_mask)
                label1 = (label1 + label2_3_mask) > 0
                label2 = label2_3_mask
                label3 = label2 * label3
                vox_3 = np.asarray(label3 > 0, np.float32).sum()
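                # treat very small enhancing-tumor regions (under 30 voxels) as false positives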
                if (0 < vox_3 and vox_3 < 30):
                    label3 = np.zeros_like(label2)

                # 5.5, convert label and save output
                out_label = label1 * 2
                #if cavity needs to be segmented, label it 3
                if ('brain_cavity' in config_data['modality_postfix']
                        and 'mha' in config_data['file_postfix']):
                    out_label[label2 > 0] = 3
                    out_label[label3 == 1] = 1
                    out_label[label3 == 2] = 4
                #incorporate cavity into the tumor core
                elif ('brain_flair' in config_data['modality_postfix']
                      and 'nii' in config_data['file_postfix']):
                    out_label[label2 > 0] = 1
                    out_label[label3 > 0] = 4
                out_label = np.asarray(out_label, np.int16)

            test_time.append(time.time() - t0)
            final_label = np.zeros(temp_size, np.int16)
            final_label = set_ND_volume_roi_with_bounding_box_range(
                final_label, temp_bbox[0], temp_bbox[1], out_label)
            # create the save folder if it does not exist
            subfolder = f'{save_folder}/{temp_name}'
            print(subfolder)
            if not os.path.exists(subfolder):
                os.makedirs(subfolder)
            save_array_as_nifty_volume(
                final_label, subfolder +
                "/{}_seg_whole.nii.gz".format(temp_name.split('/')[-1]),
                img_names[0])
            tumor_volume = calculate_tumor(
                subfolder +
                "/{}_seg_whole.nii.gz".format(temp_name.split('/')[-1]))
            # print a tumor volume report for each case
            volume_report = f"The quantitative volumetry report of accession number {temp_name.split('/')[-1]} suggests the total tumor volume" \
                            f" was {tumor_volume['total tumor volume']} {tumor_volume['unit']} " \
                            f"(enhancing portion was {tumor_volume['enhancing portion']} {tumor_volume['unit']}; " \
                            f"non-enhancing portion was {tumor_volume['non enhancing portion']} {tumor_volume['unit']}) " \
                            f"and the total vasogenic edema volume was {tumor_volume['total vasogenic edema volume']} {tumor_volume['unit']}.\n"
            print(volume_report)
        test_time = np.asarray(test_time)
        print('test time', test_time.mean())
        sess.close()
Beispiel #13
0
def train(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_common = config['common']
    config_data = config['data']
    config_net = config['network']
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))
    assert (config_data['with_ground_truth'])

    class_num = config_net['class_num']

    # 2 load data
    data_names = get_patient_names(config_data["data_names"])
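    # fixed-seed shuffle makes the train/validation split reproducible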
    split_idx = int(config_train["train_val_ratio"] * len(data_names))
    random.Random(42).shuffle(data_names)
    config_data["train_names"] = data_names[:split_idx]
    config_data["val_names"] = data_names[split_idx:]
    dataset_tr = DataLoader("train", config_data)
    dataset_tr.load_data()
    train_loader = torch.utils.data.DataLoader(
        dataset_tr,
        batch_size=config_train['train_batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True)

    dataset_val = DataLoader("validate", config_data)
    dataset_val.load_data()
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=config_train['val_batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True)

    # 3, load model
    # load pretrained
    empty_model = ARCNet(class_num,
                         vae_enable=config_train['vae_enable'],
                         config=config_data)
    if config_train['model_pre_trained']:
        arcnet_model = torch.load(config_train['model_pre_trained'])
    else:
        arcnet_model = ARCNet(class_num,
                              vae_enable=config_train['vae_enable'],
                              config=config_data)

    # 4, start to train
    solver = Solver(
        arcnet_model,
        exp_name=config_train['exp_name'],
        device=config_common['device'],
        class_num=config_net['class_num'],
        optim_args={
            "lr": config_train['learning_rate'],
            "betas": config_train['optim_betas'],
            "eps": config_train['optim_eps'],
            "weight_decay": config_train['optim_weight_decay']
        },
        loss_args={
            "vae_loss": config_train['vae_enable'],
            "loss_k1_weight": config_train['loss_k1_weight'],
            "loss_k2_weight": config_train['loss_k2_weight']
        },
        model_name=config_common['model_name'],
        labels=config_data['labels'],
        log_nth=config_train['log_nth'],
        num_epochs=config_train['num_epochs'],
        lr_scheduler_step_size=config_train['lr_scheduler_step_size'],
        lr_scheduler_gamma=config_train['lr_scheduler_gamma'],
        use_last_checkpoint=config_train['use_last_checkpoint'],
        log_dir=config_common['log_dir'],
        exp_dir=config_common['exp_dir'])

    solver.train(train_loader, val_loader)
    if not os.path.exists(config_common['save_model_dir']):
        os.makedirs(config_common['save_model_dir'])
    final_model_path = os.path.join(config_common['save_model_dir'],
                                    config_train['final_model_file'])
    solver.model = empty_model
    solver.save_best_model(final_model_path)
    print("final model saved @ " + str(final_model_path))
Beispiel #14
0
def run(stage, config_file):
    # load configuration parameters
    config = parse_config(config_file)
    config_data = config['data']
    config_net = config['network']
    config_train = config['training']
    config_test = config['testing']

    if (stage == 'train'):
        random.seed(config_train.get('random_seed', 1))
        assert (config_data['with_ground_truth'])

    net_type = config_net['net_type']
    net_name = config_net['net_name']
    data_shape = config_net['data_shape']
    label_shape = config_net['label_shape']
    data_channel = config_net['data_channel']
    class_num = config_net['class_num']
    batch_size = config_data.get('batch_size', 5)

    # construct graph
    print('data_channel', data_channel)
    full_data_shape = [batch_size] + data_shape + [data_channel]
    full_label_shape = [batch_size] + label_shape + [1]
    x = tf.placeholder(tf.float32, shape=full_data_shape)
    w = tf.placeholder(tf.float32, shape=full_label_shape)
    y = tf.placeholder(tf.int64, shape=full_label_shape)

    w_regularizer = None
    b_regularizer = None
    if (stage == 'train'):
        w_regularizer = regularizers.l2_regularizer(
            config_train.get('decay', 1e-7))
        b_regularizer = regularizers.l2_regularizer(
            config_train.get('decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes=class_num,
                    w_regularizer=w_regularizer,
                    b_regularizer=b_regularizer,
                    name=net_name)

    if (net_type == 'MSNet'):
        net.set_params(config_net)

    loss_func = LossFunction(n_class=class_num)
    predicty = net(x, is_training=True)
    proby = tf.nn.softmax(predicty)
    loss = loss_func(predicty, y, weight_map=w)
    print('size of predicty:', predicty)

    # Initialize session and saver
    if (stage == 'train'):
        lr = config_train.get('learning_rate', 1e-3)
        opt_step = tf.train.AdamOptimizer(lr).minimize(loss)
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    loader = DataLoader()
    loader.set_params(config_data)
    loader.load_data()

    if (stage == 'train'):
        loss_list = []
        loss_file = config_train['model_save_prefix'] + "_loss.txt"
        start_it = config_train.get('start_iteration', 0)
        if (start_it > 0):
            saver.restore(sess, config_train['model_pre_trained'])
        for n in range(start_it, config_train['maximal_iteration']):
            train_pair = loader.get_subimage_batch()
            tempx = train_pair['images']
            tempw = train_pair['weights']
            tempy = train_pair['labels']
            opt_step.run(session=sess,
                         feed_dict={
                             x: tempx,
                             w: tempw,
                             y: tempy
                         })

            if (n % config_train['test_iteration'] == 0):
                batch_dice_list = []
                for step in range(config_train['test_step']):
                    train_pair = loader.get_subimage_batch()
                    tempx = train_pair['images']
                    tempw = train_pair['weights']
                    tempy = train_pair['labels']
                    tempp = train_pair['probs']
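                    # despite the name, this evaluates the weighted loss, not a Dice score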
                    dice = loss.eval(feed_dict={x: tempx, w: tempw, y: tempy})
                    batch_dice_list.append(dice)
                batch_dice = np.asarray(batch_dice_list, np.float32).mean()
                t = time.strftime('%X %x %Z')
                print(t, 'n', n, 'loss', batch_dice)
                loss_list.append(batch_dice)
                np.savetxt(loss_file, np.asarray(loss_list))
            if ((n + 1) % config_train['snapshot_iteration'] == 0):
                saver.save(
                    sess, config_train['model_save_prefix'] +
                    "_{0:}.ckpt".format(n + 1))
    else:
        saver.restore(sess, config_test['model_file'])
        image_num = loader.get_total_image_number()
        test_slice_direction = config_test.get('test_slice_direction', 'all')
        save_folder = config_test['save_folder']
        test_time = []
        for i in range(image_num):
            [test_imgs, test_weight,
             temp_name] = loader.get_image_data_with_name(i)
            down_sample = config_test.get('down_sample_rate', 1.0)
            if (down_sample == 1.0):
                temp_imgs = test_imgs
                temp_weight = test_weight
            else:
                temp_imgs = []
                for mod in range(len(test_imgs)):
                    temp_imgs.append(
                        ndimage.interpolation.zoom(test_imgs[mod],
                                                   1.0 / down_sample,
                                                   order=1))
                temp_weight = ndimage.interpolation.zoom(test_weight,
                                                         1.0 / down_sample,
                                                         order=1)
            t0 = time.time()
            if (net_type == 'HighRes3DNet' or net_type == 'UNet3D'):
                temp_prob = volume_probability_prediction_3d_roi(temp_imgs, data_shape, \
                    label_shape, data_channel, class_num, batch_size, sess, proby, x)
            else:
                temp_prob = test_one_image(temp_imgs, data_shape, label_shape,
                                           data_channel, class_num, batch_size,
                                           test_slice_direction, sess, proby,
                                           x)
            temp_time = time.time() - t0
            test_time.append(temp_time)
            temp_label = np.asarray(np.argmax(temp_prob, axis=3), np.uint16)
            temp_label[temp_weight == 0] = 0
            label_convert_source = config_test.get('label_convert_source',
                                                   None)
            label_convert_target = config_test.get('label_convert_target',
                                                   None)
            if (label_convert_source and label_convert_target):
                assert (len(label_convert_source) == len(label_convert_target))
                temp_label = convert_label(temp_label, label_convert_source,
                                           label_convert_target)
            if (down_sample != 1.0):
                temp_label = resize_3D_volume_to_given_shape(temp_label,
                                                             test_weight.shape,
                                                             order=0)
            save_array_as_nifty_volume(
                temp_label, save_folder + "/{0:}.nii.gz".format(temp_name))
            # save probability
            save_prob = config_test.get('save_prob', False)
            if (save_prob):
                fg_prob = temp_prob[:, :, :, 1]
                fg_prob = np.reshape(fg_prob, temp_label.shape)
                save_array_as_nifty_volume(
                    fg_prob,
                    save_folder + "_prob/{0:}.nii.gz".format(temp_name))


#             pickle.dump(temp_prob, open(save_folder+"_prob/{0:}.p".format(temp_name), 'w'))
        test_time = np.asarray(test_time)
        print('test time', test_time.mean(), test_time.std())
        np.savetxt(save_folder + '/test_time.txt', test_time)

    sess.close()
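A plausible command-line entry point for run(); the script name and argument layout are assumptions rather than part of the original example:

# hypothetical wrapper, assuming this function lives in a script named run.py
if __name__ == '__main__':
    import sys
    if len(sys.argv) != 3:
        print('usage: python run.py <train|test> <config_file>')
        exit()
    run(sys.argv[1], sys.argv[2])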
Beispiel #15
0
def produce_adv(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_adv = config['adversarial']
    config_data = config['data']
    config_net = config['network']
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))
    assert (config_data['with_ground_truth'])

    net_type = config_net['net_type']
    net_name = config_net['net_name']
    class_num = config_net['class_num']
    batch_size = config_data.get('batch_size', 5)

    # 2, construct graph
    print("Constructing graph...")
    full_data_shape = [batch_size] + config_data['data_shape']
    full_label_shape = [batch_size] + config_data['label_shape']
    x = tf.placeholder(tf.float32, shape=full_data_shape)
    w = tf.placeholder(tf.float32, shape=full_label_shape)
    y = tf.placeholder(tf.int64, shape=full_label_shape)

    w_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get(
        'decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes=class_num,
                    w_regularizer=w_regularizer,
                    b_regularizer=b_regularizer,
                    name=net_name)
    net.set_params(config_net)
    predicty = net(x, is_training=True)
    proby = tf.nn.softmax(predicty)

    loss_func = LossFunction(n_class=class_num)
    loss = loss_func(predicty, y, weight_map=w)
    print("old y = %s" % str(y))
    print("old logits = %s" % str(predicty))
    print('size of predicty:', predicty)

    loss_dists = []
    # 3, initialize session and saver
    with tf.Session() as sess:
        saver = tf.train.Saver()
        adv_method = config_adv['method']
        if adv_method == 'FastGradientMethod':
            attack = FastGradientMethod(net)
        elif adv_method == 'MomentumIterativeMethod':
            attack = MomentumIterativeMethod(net)
        elif adv_method == 'ProjectedGradientDescent':
            attack = ProjectedGradientDescent(net)
        elif adv_method == 'GaussianNoise':
            attack = GaussianNoise(net)
        else:
            print("Unknown adversary %s" % adv_method)
            exit()
        #fgsm_params = {'eps': 0.01}
        #adv_x = fgsm.generate(x, **fgsm_params)

        adv_steps = (config_adv['iterations']
                     if adv_method == 'FastGradientMethod' else 1)
        eps_vals = np.arange(0.1, 1.5, 0.2)
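        # FastGradientMethod is applied adv_steps times in the loop below,
        # while the iterative attacks run their own nb_iter steps internally;
        # np.arange(0.1, 1.5, 0.2) sweeps eps over 0.1, 0.3, ..., 1.3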
        for adv_eps in eps_vals:
            #adv_eps = config_adv['eps']
            adv_eps = float(adv_eps)

            if config_adv['targeted']:
                # targeted attack: push the network towards predicting no tumor
                params = {
                    'eps': adv_eps / adv_steps,
                    'y_target': tf.zeros_like(y),
                    'loss_func': lambda logits, labels: loss_func(
                        logits, labels, weight_map=w)
                }
            else:
                params = {
                    'eps': adv_eps / adv_steps,
                    'y': y,
                    'loss_func': lambda logits, labels: loss_func(
                        logits, labels, weight_map=w)
                }

            if adv_method != 'FastGradientMethod':
                params['nb_iter'] = config_adv['iterations']

            print("Generating %s attack with method %s; eps=%f iterations=%d" %
                  ("targeted" if config_adv['targeted'] else "", adv_method,
                   adv_eps, config_adv['iterations']))
            print("Params: %s" % str(params))
            adv_x = attack.generate(x, **params)

            sess.run(tf.global_variables_initializer())

            dataloader = DataLoader(config_data)
            dataloader.load_data()

            # 4, restore the trained model and run the attack
            print("Restoring model...")
            loss_file = config_train['model_save_prefix'] + "_loss.txt"
            saver.restore(sess, config_train['model_pre_trained'])

            num_images = dataloader.get_total_image_number()
            print("Running adversary on %d images" % num_images)
            # per-image loss values for original and adversarial inputs
            batch_dice_list = []
            batch_dice_list_adv = []
            batch_dice_list_diff = []
            for n in range(num_images):
                print("FGSM img " + str(n + 1))
                #[temp_imgs, temp_weight, temp_name, img_names, temp_bbox, temp_size] = dataloader.get_image_data_with_name(i)
                train_pair = dataloader.get_subimage_batch()
                tempx = train_pair['images']
                adv = np.copy(tempx)
                tempw = train_pair['weights']
                tempy = train_pair['labels']

                for i in range(adv_steps):
                    with sess.as_default():
                        adv = sess.run(adv_x,
                                       feed_dict={
                                           x: adv,
                                           y: tempy,
                                           w: tempw
                                       })

                print("Saving inputs and outputs...")

                for i in range(batch_size):
                    save_array_as_nifty_volume(
                        tempx[i], config_adv['save_folder'] +
                        "/img_{0:}.nii.gz".format(n * batch_size + i))
                    save_array_as_nifty_volume(
                        adv[i], config_adv['save_folder'] +
                        "/img_{0:}_adv.nii.gz".format(n * batch_size + i))

                label_og = sess.run(proby, feed_dict={x: tempx, w: tempw})
                label_adv = sess.run(proby, feed_dict={x: adv, w: tempw})

                save_label = np.asarray(tempy, np.float32)
                for i in range(batch_size):
                    save_array_as_nifty_volume(
                        label_og[i], config_adv['save_folder'] +
                        "/label_{0:}.nii.gz".format(n * batch_size + i))
                    save_array_as_nifty_volume(
                        label_adv[i], config_adv['save_folder'] +
                        "/label_{0:}_adv.nii.gz".format(n * batch_size + i))
                    save_array_as_nifty_volume(
                        save_label[i], config_adv['save_folder'] +
                        "/label_{0:}_true.nii.gz".format(n * batch_size + i))

                # compare the loss on original and adversarial inputs
                loss_tempx = loss.eval(feed_dict={
                    x: tempx,
                    w: tempw,
                    y: tempy
                })
                loss_adv = loss.eval(feed_dict={x: adv, w: tempw, y: tempy})
                print("OG loss: %f Adv loss: %f" % (loss_tempx, loss_adv))
                batch_dice_list.append(loss_tempx)
                batch_dice_list_adv.append(loss_adv)
                batch_dice_list_diff.append(loss_adv - loss_tempx)

            batch_dice_og = np.asarray(batch_dice_list, np.float32).mean()
            batch_dice_adv = np.asarray(batch_dice_list_adv, np.float32).mean()
            t = time.strftime('%X %x %Z')
            print(t, 'n', n, 'loss_og', batch_dice_og, 'loss_adv',
                  batch_dice_adv)
            loss_dists.append(batch_dice_list_diff)
        plot(loss_dists, eps_vals, adv_method)  # label the plot with the configured attack
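plot() is not defined in this example; a minimal matplotlib sketch that is consistent with the call above (the signature and styling are assumptions) could look like this:

import matplotlib.pyplot as plt

def plot(loss_dists, eps_vals, title):
    # one box of per-image loss differences for each epsilon value
    plt.boxplot(loss_dists, labels=['%.1f' % e for e in eps_vals])
    plt.xlabel('eps')
    plt.ylabel('adversarial loss - original loss')
    plt.title(title)
    plt.savefig(title.replace(' ', '_') + '_loss_dist.png')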
Beispiel #16
0
def perform_evaluation(config_file):
    # 1, load configure file
    config = parse_config(config_file)
    config_data = config['data']
    config_net1 = config.get('network1', None)
    config_net2 = config.get('network2', None)
    config_net3 = config.get('network3', None)
    config_test = config['testing']
    batch_size = config_test.get('batch_size', 5)

    # 2.1, network for whole tumor
    if (config_net1):
        net_type1 = config_net1['net_type']
        net_name1 = config_net1['net_name']
        data_shape1 = config_net1['data_shape']
        label_shape1 = config_net1['label_shape']
        class_num1 = config_net1['class_num']

        # construct graph for 1st network
        full_data_shape1 = [batch_size] + data_shape1
        x1 = tf.placeholder(tf.float32, shape=full_data_shape1)
        net_class1 = NetFactory.create(net_type1)
        net1 = net_class1(num_classes=class_num1,
                          w_regularizer=None,
                          b_regularizer=None,
                          name=net_name1)
        net1.set_params(config_net1)
        predicty1 = net1(x1, is_training=True)
        proby1 = tf.nn.softmax(predicty1)
    else:
        config_net1ax = config['network1ax']
        config_net1sg = config['network1sg']
        config_net1cr = config['network1cr']

        # construct graph for 1st network axial
        net_type1ax = config_net1ax['net_type']
        net_name1ax = config_net1ax['net_name']
        data_shape1ax = config_net1ax['data_shape']
        label_shape1ax = config_net1ax['label_shape']
        class_num1ax = config_net1ax['class_num']

        full_data_shape1ax = [batch_size] + data_shape1ax
        x1ax = tf.placeholder(tf.float32, shape=full_data_shape1ax)
        net_class1ax = NetFactory.create(net_type1ax)
        net1ax = net_class1ax(num_classes=class_num1ax,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1ax)
        net1ax.set_params(config_net1ax)
        predicty1ax = net1ax(x1ax, is_training=True)
        proby1ax = tf.nn.softmax(predicty1ax)

        # construct graph for 1st network sagittal
        net_type1sg = config_net1sg['net_type']
        net_name1sg = config_net1sg['net_name']
        data_shape1sg = config_net1sg['data_shape']
        label_shape1sg = config_net1sg['label_shape']
        class_num1sg = config_net1sg['class_num']

        full_data_shape1sg = [batch_size] + data_shape1sg
        x1sg = tf.placeholder(tf.float32, shape=full_data_shape1sg)
        net_class1sg = NetFactory.create(net_type1sg)
        net1sg = net_class1sg(num_classes=class_num1sg,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1sg)
        net1sg.set_params(config_net1sg)
        predicty1sg = net1sg(x1sg, is_training=True)
        proby1sg = tf.nn.softmax(predicty1sg)

        # construct graph for 1st network coronal
        net_type1cr = config_net1cr['net_type']
        net_name1cr = config_net1cr['net_name']
        data_shape1cr = config_net1cr['data_shape']
        label_shape1cr = config_net1cr['label_shape']
        class_num1cr = config_net1cr['class_num']

        full_data_shape1cr = [batch_size] + data_shape1cr
        x1cr = tf.placeholder(tf.float32, shape=full_data_shape1cr)
        net_class1cr = NetFactory.create(net_type1cr)
        net1cr = net_class1cr(num_classes=class_num1cr,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1cr)
        net1cr.set_params(config_net1cr)
        predicty1cr = net1cr(x1cr, is_training=True)
        proby1cr = tf.nn.softmax(predicty1cr)

    if (config_test.get('whole_tumor_only', False) is False):
        # 2.2, networks for tumor core
        if (config_net2):
            net_type2 = config_net2['net_type']
            net_name2 = config_net2['net_name']
            data_shape2 = config_net2['data_shape']
            label_shape2 = config_net2['label_shape']
            class_num2 = config_net2['class_num']

            # construct graph for 2nd network
            full_data_shape2 = [batch_size] + data_shape2
            x2 = tf.placeholder(tf.float32, shape=full_data_shape2)
            net_class2 = NetFactory.create(net_type2)
            net2 = net_class2(num_classes=class_num2,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name2)
            net2.set_params(config_net2)
            predicty2 = net2(x2, is_training=True)
            proby2 = tf.nn.softmax(predicty2)
        else:
            config_net2ax = config['network2ax']
            config_net2sg = config['network2sg']
            config_net2cr = config['network2cr']

            # construct graph for 2nd network axial
            net_type2ax = config_net2ax['net_type']
            net_name2ax = config_net2ax['net_name']
            data_shape2ax = config_net2ax['data_shape']
            label_shape2ax = config_net2ax['label_shape']
            class_num2ax = config_net2ax['class_num']

            full_data_shape2ax = [batch_size] + data_shape2ax
            x2ax = tf.placeholder(tf.float32, shape=full_data_shape2ax)
            net_class2ax = NetFactory.create(net_type2ax)
            net2ax = net_class2ax(num_classes=class_num2ax,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name2ax)
            net2ax.set_params(config_net2ax)
            predicty2ax = net2ax(x2ax, is_training=True)
            proby2ax = tf.nn.softmax(predicty2ax)

            # construct graph for 2nd network sagittal
            net_type2sg = config_net2sg['net_type']
            net_name2sg = config_net2sg['net_name']
            data_shape2sg = config_net2sg['data_shape']
            label_shape2sg = config_net2sg['label_shape']
            class_num2sg = config_net2sg['class_num']

            full_data_shape2sg = [batch_size] + data_shape2sg
            x2sg = tf.placeholder(tf.float32, shape=full_data_shape2sg)
            net_class2sg = NetFactory.create(net_type2sg)
            net2sg = net_class2sg(num_classes=class_num2sg,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name2sg)
            net2sg.set_params(config_net2sg)
            predicty2sg = net2sg(x2sg, is_training=True)
            proby2sg = tf.nn.softmax(predicty2sg)

            # construct graph for 2nd network coronal
            net_type2cr = config_net2cr['net_type']
            net_name2cr = config_net2cr['net_name']
            data_shape2cr = config_net2cr['data_shape']
            label_shape2cr = config_net2cr['label_shape']
            class_num2cr = config_net2cr['class_num']

            full_data_shape2cr = [batch_size] + data_shape2cr
            x2cr = tf.placeholder(tf.float32, shape=full_data_shape2cr)
            net_class2cr = NetFactory.create(net_type2cr)
            net2cr = net_class2cr(num_classes=class_num2cr,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name2cr)
            net2cr.set_params(config_net2cr)
            predicty2cr = net2cr(x2cr, is_training=True)
            proby2cr = tf.nn.softmax(predicty2cr)

        # 2.3, networks for enhancing tumor
        if (config_net3):
            net_type3 = config_net3['net_type']
            net_name3 = config_net3['net_name']
            data_shape3 = config_net3['data_shape']
            label_shape3 = config_net3['label_shape']
            class_num3 = config_net3['class_num']

            # construct graph for 3rd network
            full_data_shape3 = [batch_size] + data_shape3
            x3 = tf.placeholder(tf.float32, shape=full_data_shape3)
            net_class3 = NetFactory.create(net_type3)
            net3 = net_class3(num_classes=class_num3,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name3)
            net3.set_params(config_net3)
            predicty3 = net3(x3, is_training=True)
            proby3 = tf.nn.softmax(predicty3)
        else:
            config_net3ax = config['network3ax']
            config_net3sg = config['network3sg']
            config_net3cr = config['network3cr']

            # construct graph for 3rd network axial
            net_type3ax = config_net3ax['net_type']
            net_name3ax = config_net3ax['net_name']
            data_shape3ax = config_net3ax['data_shape']
            label_shape3ax = config_net3ax['label_shape']
            class_num3ax = config_net3ax['class_num']

            full_data_shape3ax = [batch_size] + data_shape3ax
            x3ax = tf.placeholder(tf.float32, shape=full_data_shape3ax)
            net_class3ax = NetFactory.create(net_type3ax)
            net3ax = net_class3ax(num_classes=class_num3ax,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name3ax)
            net3ax.set_params(config_net3ax)
            predicty3ax = net3ax(x3ax, is_training=True)
            proby3ax = tf.nn.softmax(predicty3ax)

            # construct graph for 3rd network sagittal
            net_type3sg = config_net3sg['net_type']
            net_name3sg = config_net3sg['net_name']
            data_shape3sg = config_net3sg['data_shape']
            label_shape3sg = config_net3sg['label_shape']
            class_num3sg = config_net3sg['class_num']
            full_data_shape3sg = [batch_size] + data_shape3sg
            x3sg = tf.placeholder(tf.float32, shape=full_data_shape3sg)
            net_class3sg = NetFactory.create(net_type3sg)
            net3sg = net_class3sg(num_classes=class_num3sg,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name3sg)
            net3sg.set_params(config_net3sg)
            predicty3sg = net3sg(x3sg, is_training=True)
            proby3sg = tf.nn.softmax(predicty3sg)

            # construct graph for 3rd network coronal
            net_type3cr = config_net3cr['net_type']
            net_name3cr = config_net3cr['net_name']
            data_shape3cr = config_net3cr['data_shape']
            label_shape3cr = config_net3cr['label_shape']
            class_num3cr = config_net3cr['class_num']
            full_data_shape3cr = [batch_size] + data_shape3cr
            x3cr = tf.placeholder(tf.float32, shape=full_data_shape3cr)
            net_class3cr = NetFactory.create(net_type3cr)
            net3cr = net_class3cr(num_classes=class_num3cr,
                                  w_regularizer=None,
                                  b_regularizer=None,
                                  name=net_name3cr)
            net3cr.set_params(config_net3cr)
            predicty3cr = net3cr(x3cr, is_training=True)
            proby3cr = tf.nn.softmax(predicty3cr)

    # 3, create session and load trained models
    all_vars = tf.global_variables()
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    if (config_net1):
        net1_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1) + 1] == net_name1 + '/'
        ]
        saver1 = tf.train.Saver(net1_vars)
        saver1.restore(sess, config_net1['model_file'])
    else:
        net1ax_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1ax) + 1] == net_name1ax + '/'
        ]
        saver1ax = tf.train.Saver(net1ax_vars)
        saver1ax.restore(sess, config_net1ax['model_file'])
        net1sg_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1sg) + 1] == net_name1sg + '/'
        ]
        saver1sg = tf.train.Saver(net1sg_vars)
        saver1sg.restore(sess, config_net1sg['model_file'])
        net1cr_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1cr) + 1] == net_name1cr + '/'
        ]
        saver1cr = tf.train.Saver(net1cr_vars)
        saver1cr.restore(sess, config_net1cr['model_file'])

    if (config_test.get('whole_tumor_only', False) is False):
        if (config_net2):
            net2_vars = [
                x for x in all_vars
                if x.name[0:len(net_name2) + 1] == net_name2 + '/'
            ]
            saver2 = tf.train.Saver(net2_vars)
            saver2.restore(sess, config_net2['model_file'])
        else:
            net2ax_vars = [
                x for x in all_vars
                if x.name[0:len(net_name2ax) + 1] == net_name2ax + '/'
            ]
            saver2ax = tf.train.Saver(net2ax_vars)
            saver2ax.restore(sess, config_net2ax['model_file'])
            net2sg_vars = [
                x for x in all_vars
                if x.name[0:len(net_name2sg) + 1] == net_name2sg + '/'
            ]
            saver2sg = tf.train.Saver(net2sg_vars)
            saver2sg.restore(sess, config_net2sg['model_file'])
            net2cr_vars = [
                x for x in all_vars
                if x.name[0:len(net_name2cr) + 1] == net_name2cr + '/'
            ]
            saver2cr = tf.train.Saver(net2cr_vars)
            saver2cr.restore(sess, config_net2cr['model_file'])

        if (config_net3):
            net3_vars = [
                x for x in all_vars
                if x.name[0:len(net_name3) + 1] == net_name3 + '/'
            ]
            saver3 = tf.train.Saver(net3_vars)
            saver3.restore(sess, config_net3['model_file'])
        else:
            net3ax_vars = [
                x for x in all_vars
                if x.name[0:len(net_name3ax) + 1] == net_name3ax + '/'
            ]
            saver3ax = tf.train.Saver(net3ax_vars)
            saver3ax.restore(sess, config_net3ax['model_file'])
            net3sg_vars = [
                x for x in all_vars
                if x.name[0:len(net_name3sg) + 1] == net_name3sg + '/'
            ]
            saver3sg = tf.train.Saver(net3sg_vars)
            saver3sg.restore(sess, config_net3sg['model_file'])
            net3cr_vars = [
                x for x in all_vars
                if x.name[0:len(net_name3cr) + 1] == net_name3cr + '/'
            ]
            saver3cr = tf.train.Saver(net3cr_vars)
            saver3cr.restore(sess, config_net3cr['model_file'])

    # 4, load test images
    dataloader = DataLoader(config_data)
    dataloader.load_data()
    image_num = dataloader.get_total_image_number()

    # 5, start to test
    test_slice_direction = config_test.get('test_slice_direction', 'all')
    save_folder = config_data['save_folder']
    test_time = []
    struct = ndimage.generate_binary_structure(3, 2)
    margin = config_test.get('roi_patch_margin', 5)

    for i in range(image_num):
        [temp_imgs, temp_weight, temp_name, img_names, temp_bbox,
         temp_size] = dataloader.get_image_data_with_name(i)
        t0 = time.time()
        # 5.1, test of 1st network
        if (config_net1):
            data_shapes = [
                data_shape1[:-1], data_shape1[:-1], data_shape1[:-1]
            ]
            label_shapes = [
                label_shape1[:-1], label_shape1[:-1], label_shape1[:-1]
            ]
            nets = [net1, net1, net1]
            outputs = [proby1, proby1, proby1]
            inputs = [x1, x1, x1]
            class_num = class_num1
        else:
            data_shapes = [
                data_shape1ax[:-1], data_shape1sg[:-1], data_shape1cr[:-1]
            ]
            label_shapes = [
                label_shape1ax[:-1], label_shape1sg[:-1], label_shape1cr[:-1]
            ]
            nets = [net1ax, net1sg, net1cr]
            outputs = [proby1ax, proby1sg, proby1cr]
            inputs = [x1ax, x1sg, x1cr]
            class_num = class_num1ax
        prob1 = test_one_image_three_nets_adaptive_shape(temp_imgs,
                                                         data_shapes,
                                                         label_shapes,
                                                         data_shape1ax[-1],
                                                         class_num,
                                                         batch_size,
                                                         sess,
                                                         nets,
                                                         outputs,
                                                         inputs,
                                                         shape_mode=2)
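        # prob1: voxel-wise class probabilities fused over the three passes
        # (axial, sagittal, coronal when single-view networks are used)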
        pred1 = np.asarray(np.argmax(prob1, axis=3), np.uint16)
        pred1 = pred1 * temp_weight

        wt_threshold = 2000
        if (config_test.get('whole_tumor_only', False) is True):
            pred1_lc = ndimage.morphology.binary_closing(pred1,
                                                         structure=struct)
            pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold)
            out_label = pred1_lc
        else:
            # 5.2, test of 2nd network
            if (pred1.sum() == 0):
                print('net1 output is empty for', temp_name)
                bbox1 = get_ND_bounding_box(temp_imgs[0] > 0, margin)
            else:
                pred1_lc = ndimage.morphology.binary_closing(pred1,
                                                             structure=struct)
                pred1_lc = get_largest_two_component(pred1_lc, False,
                                                     wt_threshold)
                bbox1 = get_ND_bounding_box(pred1_lc, margin)
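            # crop each modality and the weight map to the whole-tumor ROI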
            sub_imgs = [
                crop_ND_volume_with_bounding_box(one_img, bbox1[0], bbox1[1])
                for one_img in temp_imgs
            ]
            sub_weight = crop_ND_volume_with_bounding_box(
                temp_weight, bbox1[0], bbox1[1])

            if (config_net2):
                data_shapes = [
                    data_shape2[:-1], data_shape2[:-1], data_shape2[:-1]
                ]
                label_shapes = [
                    label_shape2[:-1], label_shape2[:-1], label_shape2[:-1]
                ]
                nets = [net2, net2, net2]
                outputs = [proby2, proby2, proby2]
                inputs = [x2, x2, x2]
                class_num = class_num2
            else:
                data_shapes = [
                    data_shape2ax[:-1], data_shape2sg[:-1], data_shape2cr[:-1]
                ]
                label_shapes = [
                    label_shape2ax[:-1], label_shape2sg[:-1],
                    label_shape2cr[:-1]
                ]
                nets = [net2ax, net2sg, net2cr]
                outputs = [proby2ax, proby2sg, proby2cr]
                inputs = [x2ax, x2sg, x2cr]
                class_num = class_num2ax
            prob2 = test_one_image_three_nets_adaptive_shape(sub_imgs,
                                                             data_shapes,
                                                             label_shapes,
                                                             data_shape2ax[-1],
                                                             class_num,
                                                             batch_size,
                                                             sess,
                                                             nets,
                                                             outputs,
                                                             inputs,
                                                             shape_mode=1)
            pred2 = np.asarray(np.argmax(prob2, axis=3), np.uint16)
            pred2 = pred2 * sub_weight

            # 5.3, test of 3rd network
            if (pred2.sum() == 0):
                [roid, roih, roiw] = sub_imgs[0].shape
                bbox2 = [[0, 0, 0], [roid - 1, roih - 1, roiw - 1]]
                subsub_imgs = sub_imgs
                subsub_weight = sub_weight
            else:
                pred2_lc = ndimage.morphology.binary_closing(pred2,
                                                             structure=struct)
                pred2_lc = get_largest_two_component(pred2_lc)
                bbox2 = get_ND_bounding_box(pred2_lc, margin)
                subsub_imgs = [
                    crop_ND_volume_with_bounding_box(one_img, bbox2[0],
                                                     bbox2[1])
                    for one_img in sub_imgs
                ]
                subsub_weight = crop_ND_volume_with_bounding_box(
                    sub_weight, bbox2[0], bbox2[1])

            if (config_net3):
                data_shapes = [
                    data_shape3[:-1], data_shape3[:-1], data_shape3[:-1]
                ]
                label_shapes = [
                    label_shape3[:-1], label_shape3[:-1], label_shape3[:-1]
                ]
                nets = [net3, net3, net3]
                outputs = [proby3, proby3, proby3]
                inputs = [x3, x3, x3]
                class_num = class_num3
            else:
                data_shapes = [
                    data_shape3ax[:-1], data_shape3sg[:-1], data_shape3cr[:-1]
                ]
                label_shapes = [
                    label_shape3ax[:-1], label_shape3sg[:-1],
                    label_shape3cr[:-1]
                ]
                nets = [net3ax, net3sg, net3cr]
                outputs = [proby3ax, proby3sg, proby3cr]
                inputs = [x3ax, x3sg, x3cr]
                class_num = class_num3ax

            prob3 = test_one_image_three_nets_adaptive_shape(subsub_imgs,
                                                             data_shapes,
                                                             label_shapes,
                                                             data_shape3ax[-1],
                                                             class_num,
                                                             batch_size,
                                                             sess,
                                                             nets,
                                                             outputs,
                                                             inputs,
                                                             shape_mode=1)

            pred3 = np.asarray(np.argmax(prob3, axis=3), np.uint16)
            pred3 = pred3 * subsub_weight

            # 5.4, fuse results at 3 levels
            # embed the stage-3 prediction back into the full-size volume
            label3_roi = np.zeros_like(pred2)
            label3_roi = set_ND_volume_roi_with_bounding_box_range(
                label3_roi, bbox2[0], bbox2[1], pred3)
            label3 = np.zeros_like(pred1)
            label3 = set_ND_volume_roi_with_bounding_box_range(
                label3, bbox1[0], bbox1[1], label3_roi)

            label2 = np.zeros_like(pred1)
            label2 = set_ND_volume_roi_with_bounding_box_range(
                label2, bbox1[0], bbox1[1], pred2)

            label1_mask = (pred1 + label2 + label3) > 0
            label1_mask = ndimage.morphology.binary_closing(label1_mask,
                                                            structure=struct)
            label1_mask = get_largest_two_component(label1_mask, False,
                                                    wt_threshold)
            label1 = pred1 * label1_mask

            label2_3_mask = (label2 + label3) > 0
            label2_3_mask = label2_3_mask * label1_mask
            label2_3_mask = ndimage.morphology.binary_closing(label2_3_mask,
                                                              structure=struct)
            label2_3_mask = remove_external_core(label1, label2_3_mask)
            if (label2_3_mask.sum() > 0):
                label2_3_mask = get_largest_two_component(label2_3_mask)
            label1 = (label1 + label2_3_mask) > 0
            label2 = label2_3_mask
            label3 = label2 * label3
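            # an enhancing-core prediction smaller than 30 voxels is treated as noise and removed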
            vox_3 = np.asarray(label3 > 0, np.float32).sum()
            if (0 < vox_3 and vox_3 < 30):
                label3 = np.zeros_like(label2)

            # 5.5, convert label and save output
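            # label1/label2/label3 correspond to whole tumor, tumor core and enhancing
            # core; the numeric encoding below differs between the .mha and .nii datasets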
            out_label = label1 * 2
            if ('Flair' in config_data['modality_postfix']
                    and 'mha' in config_data['file_postfix']):
                out_label[label2 > 0] = 3
                out_label[label3 == 1] = 1
                out_label[label3 == 2] = 4
            elif (('flair' in config_data['modality_postfix']
                   or 'FLAIR' in config_data['modality_postfix'])
                  and 'nii' in config_data['file_postfix']):
                out_label[label2 > 0] = 1
                out_label[label3 > 0] = 4
            out_label = np.asarray(out_label, np.int16)

        test_time.append(time.time() - t0)
        final_label = np.zeros(temp_size, np.int16)
        final_label = set_ND_volume_roi_with_bounding_box_range(
            final_label, temp_bbox[0], temp_bbox[1], out_label)
        save_array_as_nifty_volume(
            final_label, './' + save_folder + "/{0:}.nii.gz".format(temp_name),
            img_names[0])
        print(temp_name)
    test_time = np.asarray(test_time)
    print('test time', test_time.mean())
    np.savetxt(save_folder + '/test_time.txt', test_time)
    sess.close()
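The cascade above leans on a handful of ND array helpers whose behaviour the calls imply. Below is a minimal, hedged sketch of that assumed behaviour (inclusive [min, max] corners, margin clipped to the volume bounds); the repository's real get_ND_bounding_box and crop_ND_volume_with_bounding_box may differ in detail.

import numpy as np

def get_ND_bounding_box(volume, margin):
    # sketch: bounding box of the nonzero region, expanded by `margin`
    # on every axis and clipped to the volume bounds
    idx = np.nonzero(volume)
    bb_min = [max(int(i.min()) - margin, 0) for i in idx]
    bb_max = [min(int(i.max()) + margin, s - 1)
              for i, s in zip(idx, volume.shape)]
    return [bb_min, bb_max]

def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max):
    # sketch: crop with inclusive upper corners, matching the
    # [min, max] convention returned above
    return volume[tuple(slice(lo, hi + 1)
                        for lo, hi in zip(bb_min, bb_max))]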
Example #17
                        help='discriminator steps')
    parser.add_argument('--c_gp_x',
                        type=float,
                        default=10.,
                        help='coefficient for gradient penalty x')
    parser.add_argument('--lamda',
                        type=float,
                        default=.1,
                        help='coefficient for divergence of z')
    parser.add_argument('--output_path',
                        type=str,
                        default='./',
                        help='output path')
    parser.add_argument('-config')
    args = parser.parse_args()
    config = parse_config(args.config)
    config_data = config['data']

    print("Loading data...")
    # dataset iterator
    dataloader = DataLoader(config_data)
    dataloader.load_data()
    batch_size = config_data['batch_size']
    full_data_shape = [batch_size] + config_data['data_shape']

    #train_gen, dev_gen, test_gen = tflib.mnist.load(args.batch_size, args.batch_size)


    def inf_train_gen():
        while True:
            train_pair = dataloader.get_subimage_batch()
Example #18
def test(config_file):
    # 1, load configure file
    config = parse_config(config_file)
    config_data = config['data']
    config_net1 = config.get('network1', None)
    config_net2 = config.get('network2', None)
    config_net3 = config.get('network3', None)
    config_test = config['testing']
    batch_size = config_test.get('batch_size', 5)

    # 2.1, network for whole tumor
    if (config_net1):
        net_type1 = config_net1['net_type']
        net_name1 = config_net1['net_name']
        data_shape1 = config_net1['data_shape']
        label_shape1 = config_net1['label_shape']
        class_num1 = config_net1['class_num']

        # construct graph for 1st network
        full_data_shape1 = [batch_size] + data_shape1
        x1 = tf.placeholder(tf.float32, shape=full_data_shape1)
        net_class1 = NetFactory.create(net_type1)
        net1 = net_class1(num_classes=class_num1,
                          w_regularizer=None,
                          b_regularizer=None,
                          name=net_name1)
        net1.set_params(config_net1)
        predicty1 = net1(x1, is_training=True)
        proby1 = tf.nn.softmax(predicty1)
    else:
        config_net1ax = config['network1ax']
        config_net1sg = config['network1sg']
        config_net1cr = config['network1cr']

        # construct graph for 1st network axial
        net_type1ax = config_net1ax['net_type']
        net_name1ax = config_net1ax['net_name']
        data_shape1ax = config_net1ax['data_shape']
        label_shape1ax = config_net1ax['label_shape']
        class_num1ax = config_net1ax['class_num']

        full_data_shape1ax = [batch_size] + data_shape1ax
        x1ax = tf.placeholder(tf.float32, shape=full_data_shape1ax)
        net_class1ax = NetFactory.create(net_type1ax)
        net1ax = net_class1ax(num_classes=class_num1ax,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1ax)
        net1ax.set_params(config_net1ax)
        predicty1ax = net1ax(x1ax, is_training=True)
        proby1ax = tf.nn.softmax(predicty1ax)

        # construct graph for 1st network sagittal
        net_type1sg = config_net1sg['net_type']
        net_name1sg = config_net1sg['net_name']
        data_shape1sg = config_net1sg['data_shape']
        label_shape1sg = config_net1sg['label_shape']
        class_num1sg = config_net1sg['class_num']

        full_data_shape1sg = [batch_size] + data_shape1sg
        x1sg = tf.placeholder(tf.float32, shape=full_data_shape1sg)
        net_class1sg = NetFactory.create(net_type1sg)
        net1sg = net_class1sg(num_classes=class_num1sg,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1sg)
        net1sg.set_params(config_net1sg)
        predicty1sg = net1sg(x1sg, is_training=True)
        proby1sg = tf.nn.softmax(predicty1sg)

        # construct graph for 1st network coronal
        net_type1cr = config_net1cr['net_type']
        net_name1cr = config_net1cr['net_name']
        data_shape1cr = config_net1cr['data_shape']
        label_shape1cr = config_net1cr['label_shape']
        class_num1cr = config_net1cr['class_num']

        full_data_shape1cr = [batch_size] + data_shape1cr
        x1cr = tf.placeholder(tf.float32, shape=full_data_shape1cr)
        net_class1cr = NetFactory.create(net_type1cr)
        net1cr = net_class1cr(num_classes=class_num1cr,
                              w_regularizer=None,
                              b_regularizer=None,
                              name=net_name1cr)
        net1cr.set_params(config_net1cr)
        predicty1cr = net1cr(x1cr, is_training=True)
        proby1cr = tf.nn.softmax(predicty1cr)

    # 3, create session and load trained models
    print('create session and load trained models \n')
    model_t0 = time.time()

    # with tf.device("/device:GPU:0"): #0806
    all_vars = tf.global_variables()
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
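    # group variables by their leading name scope so each network can be
    # restored from its own checkpoint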
    if (config_net1):
        net1_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1) + 1] == net_name1 + '/'
        ]
        saver1 = tf.train.Saver(net1_vars)
        saver1.restore(sess, config_net1['model_file'])
    else:
        net1ax_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1ax) + 1] == net_name1ax + '/'
        ]
        saver1ax = tf.train.Saver(net1ax_vars)
        saver1ax.restore(sess, config_net1ax['model_file'])
        net1sg_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1sg) + 1] == net_name1sg + '/'
        ]
        saver1sg = tf.train.Saver(net1sg_vars)
        saver1sg.restore(sess, config_net1sg['model_file'])
        net1cr_vars = [
            x for x in all_vars
            if x.name[0:len(net_name1cr) + 1] == net_name1cr + '/'
        ]
        saver1cr = tf.train.Saver(net1cr_vars)
        saver1cr.restore(sess, config_net1cr['model_file'])

    print('Model load time is {}'.format(time.time() - model_t0))

    # 4, load test images
    print('load test images \n')
    load_t0 = time.time()
    dataloader = DataLoader(config_data)
    dataloader.load_data()
    image_num = dataloader.get_total_image_number()
    print('data load time is {}'.format(time.time() - load_t0))

    # 5, start to test
    print('start to test \n')
    test_slice_direction = config_test.get('test_slice_direction', 'all')
    # save_folder = config_data['save_folder']
    # Ben: save the segmentation output to the same folder as the input
    save_folder = '.'
    test_time = []
    struct = ndimage.generate_binary_structure(3, 2)
    margin = config_test.get('roi_patch_margin', 5)

    for i in range(image_num):
        [temp_imgs, temp_weight, temp_name, img_names, temp_bbox,
         temp_size] = dataloader.get_image_data_with_name(i)
        print(f'Segmenting case {temp_name}\n')
        t0 = time.time()
        # 5.1, test of 1st network
        if (config_net1):
            data_shapes = [
                data_shape1[:-1], data_shape1[:-1], data_shape1[:-1]
            ]
            label_shapes = [
                label_shape1[:-1], label_shape1[:-1], label_shape1[:-1]
            ]
            nets = [net1, net1, net1]
            outputs = [proby1, proby1, proby1]
            inputs = [x1, x1, x1]
            class_num = class_num1
        else:
            data_shapes = [
                data_shape1ax[:-1], data_shape1sg[:-1], data_shape1cr[:-1]
            ]
            label_shapes = [
                label_shape1ax[:-1], label_shape1sg[:-1], label_shape1cr[:-1]
            ]
            nets = [net1ax, net1sg, net1cr]
            outputs = [proby1ax, proby1sg, proby1cr]
            inputs = [x1ax, x1sg, x1cr]
            class_num = class_num1ax
        prob1 = test_one_image_three_nets_adaptive_shape(
            temp_imgs,
            data_shapes,
            label_shapes,
            data_shape1ax[-1],
            class_num,
            batch_size,
            sess,
            nets,
            outputs,
            inputs,
            shape_mode=2)  # average the probabilities over the axial, sagittal and coronal views
        pred1 = np.asarray(np.argmax(prob1, axis=3), np.uint16)
        pred1 = pred1 * temp_weight  # mask the prediction with the brain-region weight map

        wt_threshold = 2000
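        # clean the whole-tumor mask: binary closing, then keep the largest
        # component(s) above wt_threshold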
        pred1_lc = ndimage.morphology.binary_closing(pred1, structure=struct)
        pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold)
        out_label = pred1_lc

        test_time.append(time.time() - t0)
        final_label = np.zeros(temp_size, np.int16)
        final_label = set_ND_volume_roi_with_bounding_box_range(
            final_label, temp_bbox[0], temp_bbox[1], out_label)
        # create the output subfolder if it does not already exist
        subfolder = f'{save_folder}/{temp_name}'
        if not os.path.exists(subfolder):
            os.makedirs(subfolder)
        save_array_as_nifty_volume(
            final_label,
            subfolder + "/{}_brain.nii.gz".format(temp_name.split('/')[-1]),
            img_names[0])

    test_time = np.asarray(test_time)
    print('test time', test_time.mean())
    np.savetxt(save_folder + '/test_time.txt', test_time)
    sess.close()
Example #19
def train(config_file):
    train_mode = tf.estimator.ModeKeys.TRAIN
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_data  = config['data']
    config_net   = config['network']
    config_train = config['training']
    random.seed(config_train.get('random_seed', 1))

    # 2, setup data generator
    
    # create pre_processor and sampler
    pre_processor = get_brats_preprocess_layers(config_data, mode = train_mode)
    sampler = RandomSamplerWithCrop()
    sample_num_per_image = config_data.get('sample_num_per_image', 5)
    sample_shape = config_data['data_shape']
    sampler.set_sample_patch(sample_num_per_image, sample_shape)
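    # each training volume yields sample_num_per_image random patches of size data_shape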

    # create tensorflow dataset, iterator and initializer
    tr_data = DataSetFromTFRecord(config_data, mode = train_mode,
                                 preprocess_layers = pre_processor,
                                 sampler           = sampler)
    iterator = tf.data.Iterator.from_structure(tr_data.data.output_types,
                                       tr_data.data.output_shapes)
    next_element = iterator.get_next()
    training_init_op = iterator.make_initializer(tr_data.data)
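    # re-running training_init_op re-initializes the iterator; see the
    # OutOfRangeError handler at the end of the training loop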


    # 3, construct graph
    net_type    = config_net['net_type']
    net_name    = config_net['net_name']
    class_num   = config_net['class_num']
    batch_size  = config_data['batch_size']
    full_data_shape  = [batch_size] + config_data['data_shape'] + \
                       [config_data['feature_channel_num']]
    full_label_shape = [batch_size] + config_data['label_shape'] + \
                       [config_data['prediction_channel_num']]
    x = tf.placeholder(tf.float32, shape = full_data_shape)
    w = tf.placeholder(tf.float32, shape = full_label_shape)
    y = tf.placeholder(tf.int32,   shape = full_label_shape)
   
    w_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes = class_num,
                    w_regularizer = w_regularizer,
                    b_regularizer = b_regularizer,
                    name = net_name)
    net.set_params(config_net)
    predicty = net(x, is_training = True)
    proby    = tf.nn.softmax(predicty)
    
#     loss_func = LossFunction(n_class=class_num)
#     loss = loss_func(predicty, y, weight_map = w)
#     print('size of predicty:',predicty)
    y_soft  = get_soft_label(y, class_num)
    loss = soft_dice_loss(predicty, y_soft, weight_map = w)
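    # soft Dice loss over one-hot labels, masked by w, replaces the
    # commented-out LossFunction above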
    
    # 4, initialize session and saver
    lr = config_train.get('learning_rate', 1e-3)
    opt_step = tf.train.AdamOptimizer(lr).minimize(loss)
    sess = tf.InteractiveSession()   
    sess.run(tf.global_variables_initializer())
    sess.run(training_init_op)
    saver = tf.train.Saver()

    # 5, start to train
    loss_file = config_train['model_save_prefix'] + "_loss.txt"
    start_it  = config_train.get('start_iteration', 0)
    if( start_it > 0):
        saver.restore(sess, config_train['model_pre_trained'])
    loss_list, temp_loss_list = [], []

    margin = int((config_data['data_shape'][0]  - config_data['label_shape'][0])/2)
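    # the network output is smaller than its input along the first spatial axis,
    # so labels and weights are cropped by the same margin below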
    for n in range(start_it, config_train['maximal_iteration']):
        try:
            elem = sess.run(next_element)
            name = elem['entry_name']
            tempx = elem['entry_data']['feature']
            tempy = elem['entry_data']['prediction']
            tempw = elem['entry_data']['mask']
            if(margin > 0):
                tempy = tempy[:, margin:-margin, :, :, :]
                tempw = tempw[:, margin:-margin, :, :, :]
            opt_step.run(session = sess, feed_dict={x:tempx, w: tempw, y:tempy})
    
            if(n%config_train['test_iteration'] == 0):
                batch_dice_list = []
                for step in range(config_train['test_step']):
                    elem = sess.run(next_element)
                    tempx = elem['entry_data']['feature']
                    tempy = elem['entry_data']['prediction']
                    tempw = elem['entry_data']['mask']
                    if(margin > 0):
                        tempy = tempy[:, margin:-margin, :, :, :]
                        tempw = tempw[:, margin:-margin, :, :, :]
                    dice = loss.eval(feed_dict ={x:tempx, w:tempw, y:tempy})
                    batch_dice_list.append(dice)
                batch_dice = np.asarray(batch_dice_list, np.float32).mean()
                t = time.strftime('%X %x %Z')
                print(t, 'n', n, 'loss', batch_dice)
                loss_list.append(batch_dice)
                np.savetxt(loss_file, np.asarray(loss_list))
    
            if((n+1)%config_train['snapshot_iteration']  == 0):
                saver.save(sess, config_train['model_save_prefix']+"_{0:}.ckpt".format(n+1))
        except tf.errors.OutOfRangeError:
            sess.run(training_init_op)
    sess.close()
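For reference, here is a minimal NumPy sketch of what get_soft_label and soft_dice_loss might compute, assuming labels of shape [batch, D, H, W, 1] and logits of shape [batch, D, H, W, C]; it is an illustrative re-implementation, not the repository's code.

import numpy as np

def get_soft_label(label, num_classes):
    # sketch: one-hot encode an integer label volume of shape [..., 1]
    # along a trailing class axis
    label = np.squeeze(label, axis=-1).astype(np.int64)
    return np.eye(num_classes, dtype=np.float32)[label]

def soft_dice_loss(logits, soft_label, weight_map=None, eps=1e-5):
    # sketch: 1 - mean per-class Dice between softmax(logits) and the
    # one-hot labels, optionally masked by weight_map
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    prob = e / e.sum(axis=-1, keepdims=True)
    if weight_map is not None:
        prob = prob * weight_map            # weight_map broadcasts over classes
        soft_label = soft_label * weight_map
    axes = tuple(range(logits.ndim - 1))    # sum over batch and spatial axes
    intersect = (prob * soft_label).sum(axis=axes)
    denom = prob.sum(axis=axes) + soft_label.sum(axis=axes)
    dice = (2.0 * intersect + eps) / (denom + eps)
    return 1.0 - dice.mean()

The loss lies in [0, 1] and shrinks as the softmax mass aligns with the labels, which matches how the training loop above logs it as a Dice-style score.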