def train():
    """Build the two-branch segmentation graph and train it on sampled patches.

    Settings are read from the module-level ``options`` mapping
    (``num_epochs``, ``load_path``, ``save_path``, ``psize``, ``hsize``,
    ``wsize``, ``csize``, ``model_name``, ``batch_size``,
    ``continue_training``, ``correction``).  Sample names are read from
    ``train.txt``, one per line.

    The graph has two branches.  The first is fed channels ``[:2]`` of every
    patch and the second the remaining channels.  Each branch produces two
    feature maps; the second branch's maps are concatenated with the first
    branch's maps before the 1x1x1 classification convolutions.  The loss is
    the sum of a 2-class term (first branch) and a 5-class term (combined
    branch).  A checkpoint is written to ``save_path`` after every epoch.

    Raises:
        ValueError: if ``options['model_name']`` is not one of ``'dense48'``,
            ``'no_dense'`` or ``'dense24'``.
    """
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']
    num_labels = 5

    # rstrip('\n') rather than line[:-1]: slicing would chop a real character
    # off the last entry when the file does not end with a newline.
    with open('train.txt') as f:
        files = [line.rstrip('\n') for line in f]
    print("%d training samples" % len(files))

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32,
                                      shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32,
                                     shape=(None, PSIZE, PSIZE, PSIZE, 5))

    # Pick the network builder once; both branches use the same architecture.
    if model_name == 'dense48':
        build_branch = tf_models.BraTS2ScaleDenseNetConcat_large
    elif model_name == 'no_dense':
        build_branch = tf_models.PlainCounterpart
    elif model_name == 'dense24':
        build_branch = tf_models.BraTS2ScaleDenseNetConcat
    else:
        # Fail here instead of printing and crashing later on an unbound name.
        raise ValueError("No such model name: %r" % (model_name,))

    flair_t2_15, flair_t2_27 = build_branch(input=flair_t2_node, name='flair')
    t1_t1ce_15, t1_t1ce_27 = build_branch(input=t1_t1ce_node, name='t1')

    # The combined branch sees the first branch's features, so the
    # concatenation must happen before the first branch is reduced to scores.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum both outputs over a hard-coded 12-voxel crop per spatial axis.
    # NOTE(review): presumably 13:25 is meant to match PSIZE -- confirm.
    flair_t2_score = (flair_t2_15[:, 13:25, 13:25, 13:25, :] +
                      flair_t2_27[:, 13:25, 13:25, 13:25, :])
    t1_t1ce_score = (t1_t1ce_15[:, 13:25, 13:25, 13:25, :] +
                     t1_t1ce_27[:, 13:25, 13:25, 13:25, :])

    loss = (segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) +
            segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5))

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    # Run the ops collected under UPDATE_OPS together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=5e-4).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200,
                                   correction=options['correction'])

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())

        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                # next(gen) works on Python 2.6+ and 3; gen.next() is
                # Python 2 only.
                data, labels, centers = next(data_gen_train)
                n_centers = centers.shape[1]
                n_batches = int(np.ceil(float(n_centers) / BATCH_SIZE))
                for nb in range(n_batches):
                    # Clamp the start of the last batch so it stays
                    # full-sized; it then overlaps the previous batch.
                    offset_batch = min(nb * BATCH_SIZE,
                                       n_centers - BATCH_SIZE)
                    data_batch, label_batch = get_patches_3d(
                        data, labels,
                        centers[:, offset_batch:offset_batch + BATCH_SIZE],
                        HSIZE, WSIZE, CSIZE, PSIZE, False)
                    label_batch = label_transform(label_batch, 5)
                    _, l, acc_ft, acc_t1c = sess.run(
                        fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                        feed_dict={
                            flair_t2_node: data_batch[:, :, :, :, :2],
                            t1_t1ce_node: data_batch[:, :, :, :, 2:],
                            flair_t2_gt_node: label_batch[0],
                            t1_t1ce_gt_node: label_batch[1],
                            learning_phase(): 1,
                        })
                    acc_pi.append([acc_ft, acc_t1c])
                    loss_pi.append(l)
                    # Per-channel totals of the 2-channel label batch; the
                    # printed p% is channel 1's share of the total.
                    n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)),
                                       axis=0)
                    print('epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' %
                          (ei + 1, pi + 1, nb + 1, n_batches,
                           n_pos_sum[1] / float(np.sum(n_pos_sum)),
                           l, acc_ft, acc_t1c))

                print('patient loss: %.4f, patient acc: %.4f' %
                      (np.mean(loss_pi), np.mean(acc_pi)))

            saver.save(sess, SAVE_PATH, global_step=ei)
            print('model saved')


if __name__ == '__main__':
    train()
def main():
    """Run sliding-window inference on the test samples and report Dice scores.

    Settings are read from the module-level ``options`` mapping
    (``offset_h``, ``offset_w``, ``offset_c``, ``hsize``, ``wsize``,
    ``csize``, ``psize``, ``model_path``, ``model_name``).  Sample names are
    read from ``test_hgg.txt``.

    For every sample, windows of size ``hsize x wsize x csize`` are slid over
    a volume whose extent is hard-coded as 240 x 240 x 155.  The network
    scores for each window are accumulated into the window's central
    ``psize``-cube and the per-voxel label is the arg-max over the 5 classes.
    Dice is computed for three binary masks: label > 0 ("whole"), label in
    {1, 4} ("core") and label == 4 ("enhance").  Per-sample scores are saved
    to ``<model_name>_dice_{whole,core,enhance}.npy``.

    Raises:
        ValueError: if ``options['model_name']`` is not a supported model.
    """
    # NOTE(review): [:-2] drops the last TWO characters of every line.
    # Presumably the list file has a two-character line ending or one
    # trailing character before the newline -- confirm against the file
    # before changing this to rstrip().
    with open("test_hgg.txt") as f:
        test_files = [line[:-2] for line in f]

    num_labels = 5
    OFFSET_H = options["offset_h"]
    OFFSET_W = options["offset_w"]
    OFFSET_C = options["offset_c"]
    HSIZE = options["hsize"]
    WSIZE = options["wsize"]
    CSIZE = options["csize"]
    PSIZE = options["psize"]
    SAVE_PATH = options["model_path"]
    model_name = options["model_name"]

    # Distance from a window's corner to the corner of its central
    # PSIZE-cube (cast to int where it is used).
    OFFSET_PH = (HSIZE - PSIZE) / 2
    OFFSET_PW = (WSIZE - PSIZE) / 2
    OFFSET_PC = (CSIZE - PSIZE) / 2

    # Number of window positions per axis, including the final clamped one.
    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(
        dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2)
    )
    t1_t1ce_node = tf.placeholder(
        dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2)
    )

    # Pick the network builder once; both branches use the same architecture.
    if model_name == "dense48":
        build_branch = tf_models.BraTS2ScaleDenseNetConcat_large
    elif model_name == "no_dense":
        build_branch = tf_models.PlainCounterpart
    elif model_name in ("dense24", "dense24_nocorrection"):
        build_branch = tf_models.BraTS2ScaleDenseNetConcat
    else:
        # Fail here instead of printing and crashing later on an unbound name.
        raise ValueError("No such model name: %r" % (model_name,))

    flair_t2_15, flair_t2_27 = build_branch(input=flair_t2_node, name="flair")
    t1_t1ce_15, t1_t1ce_27 = build_branch(input=t1_t1ce_node, name="t1")

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(
        num_labels,
        kernel_size=1,
        strides=1,
        padding="same",
        name="t1_t1ce_15_cls",
    )(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(
        num_labels,
        kernel_size=1,
        strides=1,
        padding="same",
        name="t1_t1ce_27_cls",
    )(t1_t1ce_27)

    # Sum both outputs over a hard-coded 12-voxel crop per spatial axis.
    t1_t1ce_score = (
        t1_t1ce_15[:, 13:25, 13:25, 13:25, :]
        + t1_t1ce_27[:, 13:25, 13:25, 13:25, :]
    )

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []

    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)

        for test_file in test_files:
            print("predicting %s" % test_file)
            x, x_n, y = next(data_gen_test)
            pred = np.zeros([240, 240, 155, 5])

            for hi in range(batches_h):
                # Clamp so the last window ends exactly at the volume border.
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = int(offset_h + OFFSET_PH)
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = int(offset_w + OFFSET_PW)
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = int(offset_c + OFFSET_PC)

                        data = x[
                            offset_h : offset_h + HSIZE,
                            offset_w : offset_w + WSIZE,
                            offset_c : offset_c + CSIZE,
                            :,
                        ]
                        data_norm = x_n[
                            offset_h : offset_h + HSIZE,
                            offset_w : offset_w + WSIZE,
                            offset_c : offset_c + CSIZE,
                            :,
                        ]
                        data_norm = np.expand_dims(data_norm, 0)

                        # Same test as the original
                        # `not max == 0 and min == 0`: the `x` window must
                        # have a non-zero maximum AND a minimum of exactly 0.
                        if np.max(data) != 0 and np.min(data) == 0:
                            score = sess.run(
                                fetches=t1_t1ce_score,
                                feed_dict={
                                    flair_t2_node: data_norm[:, :, :, :, :2],
                                    t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                    learning_phase(): 0,
                                },
                            )
                            pred[
                                offset_ph : offset_ph + PSIZE,
                                offset_pw : offset_pw + PSIZE,
                                offset_pc : offset_pc + PSIZE,
                                :,
                            ] += np.squeeze(score)

            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)

            print("calculating dice...")
            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)

            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            try:
                dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            except ValueError:
                # A sample whose "enhance" Dice cannot be computed is left
                # out of all three score lists.
                print("Skipped.")
                continue

            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print("mean dice whole:")
        print(np.mean(dice_whole, axis=0))
        print("mean dice core:")
        print(np.mean(dice_core, axis=0))
        print("mean dice enhance:")
        print(np.mean(dice_et, axis=0))

        np.save(model_name + "_dice_whole", dice_whole)
        np.save(model_name + "_dice_core", dice_core)
        np.save(model_name + "_dice_enhance", dice_et)
        print("pred saved")
def main():
    """Run sliding-window inference, save predictions and report Dice scores.

    Settings are read from the module-level ``options`` mapping
    (``offset_h``, ``offset_w``, ``offset_c``, ``hsize``, ``wsize``,
    ``csize``, ``psize``, ``model_path``, ``model_name``, ``save_path``).
    Sample names are read from ``test.txt``, one per line, and are processed
    from the last entry to the first.

    For every sample, windows of size ``hsize x wsize x csize`` are slid over
    a volume whose extent is hard-coded as 240 x 240 x 155.  The network
    scores for each window are accumulated into the window's central
    ``psize``-cube and the per-voxel label is the arg-max over the 5 classes.
    The label volume is saved to ``<save_path><sample>_prediction.npy``.
    Dice is computed for three binary masks: label > 0 ("whole"), label in
    {1, 4} ("core") and label == 4 ("enhance").  Per-sample scores are saved
    to ``<model_name>_dice_{whole,core,enhance}.npy``.

    The body uses only constructs that behave the same on Python 2 and 3
    (single-argument ``print(...)``, ``//`` for integer offsets).

    Raises:
        ValueError: if ``options['model_name']`` is not a supported model.
    """
    # rstrip('\n') rather than line[:-1]: slicing would chop a real character
    # off the last entry when the file does not end with a newline.
    with open('test.txt') as f:
        test_files = [line.rstrip('\n') for line in f]

    num_labels = 5
    OFFSET_H = options['offset_h']
    OFFSET_W = options['offset_w']
    OFFSET_C = options['offset_c']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    PSIZE = options['psize']
    SAVE_PATH = options['model_path']
    model_name = options['model_name']

    # Distance from a window's corner to the corner of its central
    # PSIZE-cube.  Floor division keeps these usable as slice bounds on
    # Python 3, where `/` would produce floats.
    OFFSET_PH = (HSIZE - PSIZE) // 2
    OFFSET_PW = (WSIZE - PSIZE) // 2
    OFFSET_PC = (CSIZE - PSIZE) // 2

    # Number of window positions per axis, including the final clamped one.
    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))

    # Pick the network builder once; both branches use the same architecture.
    if model_name == 'dense48':
        build_branch = tf_models.BraTS2ScaleDenseNetConcat_large
    elif model_name == 'no_dense':
        build_branch = tf_models.PlainCounterpart
    elif model_name in ('dense24', 'dense24_nocorrection'):
        build_branch = tf_models.BraTS2ScaleDenseNetConcat
    else:
        # Fail here instead of printing and crashing later on an unbound name.
        raise ValueError('No such model name: %r' % (model_name,))

    flair_t2_15, flair_t2_27 = build_branch(input=flair_t2_node, name='flair')
    t1_t1ce_15, t1_t1ce_27 = build_branch(input=t1_t1ce_node, name='t1')

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum both outputs over a hard-coded 12-voxel crop per spatial axis.
    t1_t1ce_score = (t1_t1ce_15[:, 13:25, 13:25, 13:25, :] +
                     t1_t1ce_27[:, 13:25, 13:25, 13:25, :])

    saver = tf.train.Saver()
    dice_whole, dice_core, dice_et = [], [], []

    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)

        # Walk the list back to front.  The previous
        # range(len(test_files) - 1, 0, -1) stopped before index 0, so the
        # first sample was never evaluated.
        for i in reversed(range(len(test_files))):
            sample = test_files[i]
            print(i)
            print('predicting %s' % sample)
            x, x_n, y = gen_test_data(sample)
            pred = np.zeros([240, 240, 155, 5])

            for hi in range(batches_h):
                # Clamp so the last window ends exactly at the volume border.
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = offset_h + OFFSET_PH
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = offset_w + OFFSET_PW
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = offset_c + OFFSET_PC

                        data = x[offset_h:offset_h + HSIZE,
                                 offset_w:offset_w + WSIZE,
                                 offset_c:offset_c + CSIZE, :]
                        data_norm = x_n[offset_h:offset_h + HSIZE,
                                        offset_w:offset_w + WSIZE,
                                        offset_c:offset_c + CSIZE, :]
                        data_norm = np.expand_dims(data_norm, 0)

                        # Same test as the original
                        # `not max == 0 and min == 0`: the `x` window must
                        # have a non-zero maximum AND a minimum of exactly 0.
                        if np.max(data) != 0 and np.min(data) == 0:
                            score = sess.run(
                                fetches=t1_t1ce_score,
                                feed_dict={
                                    flair_t2_node: data_norm[:, :, :, :, :2],
                                    t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                    learning_phase(): 0,
                                })
                            pred[offset_ph:offset_ph + PSIZE,
                                 offset_pw:offset_pw + PSIZE,
                                 offset_pc:offset_pc + PSIZE,
                                 :] += np.squeeze(score)

            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)

            print('calculating dice...')
            pred_path = options['save_path'] + sample + '_prediction'
            print(pred_path)
            np.save(pred_path, pred)

            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)

            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)

            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print('mean dice whole:')
        print(np.mean(dice_whole, axis=0))
        print('mean dice core:')
        print(np.mean(dice_core, axis=0))
        print('mean dice enhance:')
        print(np.mean(dice_et, axis=0))

        np.save(model_name + '_dice_whole', dice_whole)
        np.save(model_name + '_dice_core', dice_core)
        np.save(model_name + '_dice_enhance', dice_et)
        print('pred saved')
def train():
    """Build the two-branch segmentation graph and train it on sampled patches.

    Settings are read from the module-level ``options`` mapping
    (``num_epochs``, ``load_path``, ``save_path``, ``psize``, ``hsize``,
    ``wsize``, ``csize``, ``model_name``, ``batch_size``,
    ``continue_training``, ``root_path``, ``gpu_ids``).  Training samples are
    whatever ``get_dataset_dirnames(options['root_path'])`` returns.

    The graph has two branches.  The first is fed channels ``[:2]`` of every
    patch and the second the remaining channels.  Each branch produces two
    feature maps; the second branch's maps are concatenated with the first
    branch's maps before the 1x1x1 classification convolutions.  The loss is
    the sum of a 2-class term (first branch) and a 5-class term (combined
    branch).

    Each run writes its checkpoints to a fresh numbered directory
    ``chkpts/<n>/``, where ``n`` is one more than the largest existing
    integer-named entry (or 0 if there is none).  One checkpoint is saved per
    epoch.

    The body uses only constructs that behave the same on Python 2 and 3
    (single-argument ``print(...)``, ``next(gen)``).

    Raises:
        ValueError: if ``options['model_name']`` is not one of ``'dense48'``,
            ``'no_dense'`` or ``'dense24'``.
    """
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']
    n_gpus = len(options['gpu_ids'])
    lr = tf.Variable(5e-4, trainable=False)
    num_labels = 5

    files = get_dataset_dirnames(options['root_path'])
    print('%d training samples' % len(files))

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32,
                                      shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32,
                                     shape=(None, PSIZE, PSIZE, PSIZE, 5))

    # Pick the network builder once; both branches use the same architecture.
    if model_name == 'dense48':
        build_branch = tf_models.BraTS2ScaleDenseNetConcat_large
    elif model_name == 'no_dense':
        build_branch = tf_models.PlainCounterpart
    elif model_name == 'dense24':
        build_branch = tf_models.BraTS2ScaleDenseNetConcat
    else:
        # Fail here instead of printing and crashing later on an unbound name.
        raise ValueError('No such model name: %r' % (model_name,))

    flair_t2_15, flair_t2_27 = build_branch(input=flair_t2_node, name='flair')
    t1_t1ce_15, t1_t1ce_27 = build_branch(input=t1_t1ce_node, name='t1')

    # The combined branch sees the first branch's features, so the
    # concatenation must happen before the first branch is reduced to scores.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum both outputs over a hard-coded 12-voxel crop per spatial axis.
    # NOTE(review): presumably 13:25 is meant to match PSIZE -- confirm.
    flair_t2_score = (flair_t2_15[:, 13:25, 13:25, 13:25, :] +
                      flair_t2_27[:, 13:25, 13:25, 13:25, :])
    t1_t1ce_score = (t1_t1ce_15[:, 13:25, 13:25, 13:25, :] +
                     t1_t1ce_27[:, 13:25, 13:25, 13:25, :])

    loss = (segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) +
            segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5))

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    # Run the ops collected under UPDATE_OPS together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200)

    def single_gpu_fn(nb, gpuname='/device:GPU:0', q=None):
        """Run one optimisation step on batch ``nb`` of the current sample.

        Reads ``data``, ``labels``, ``centers`` and ``sess`` from the
        enclosing scope at call time.  ``q`` is accepted but unused.

        Returns:
            ``(acc_ft, acc_t1c, l, n_pos_sum)``: the two accuracies, the
            loss, and the per-channel sums of the 2-channel label batch.
        """
        with tf.device(gpuname):
            # Clamp the start so an index past the end re-uses the last
            # full-sized batch instead of producing a short one.
            offset_batch = min(nb * BATCH_SIZE,
                               centers.shape[1] - BATCH_SIZE)
            data_batch, label_batch = get_patches_3d(
                data, labels,
                centers[:, offset_batch:offset_batch + BATCH_SIZE],
                HSIZE, WSIZE, CSIZE, PSIZE, False)
            label_batch = label_transform(label_batch, 5)
            # NOTE(review): unlike the inference code, no learning_phase()
            # entry is fed here -- confirm the layers train as intended.
            _, l, acc_ft, acc_t1c = sess.run(
                fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                feed_dict={
                    flair_t2_node: data_batch[:, :, :, :, :2],
                    t1_t1ce_node: data_batch[:, :, :, :, 2:],
                    flair_t2_gt_node: label_batch[0],
                    t1_t1ce_gt_node: label_batch[1],
                })
            n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)
        return acc_ft, acc_t1c, l, n_pos_sum

    # Allocate a fresh numbered run directory under chkpts/.  Entries that
    # are not plain integers are ignored, and an empty chkpts/ starts at 0
    # (the original lookup raised on both).
    if not os.path.isdir('chkpts'):
        os.mkdir('chkpts')
    run_ids = []
    for entry in glob.glob('chkpts/*'):
        entry_name = os.path.basename(entry)
        if entry_name.isdigit():
            run_ids.append(int(entry_name))
    save_point = max(run_ids) + 1 if run_ids else 0
    os.mkdir('chkpts/%d' % save_point)

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())

        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                # next(gen) works on Python 2.6+ and 3; gen.next() is
                # Python 2 only.
                data, labels, centers = next(data_gen_train)
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))

                # Batches are consumed in groups of one per configured GPU,
                # but every step currently runs sequentially on the helper's
                # default device.
                for nb in range(0, n_batches, n_gpus):
                    for gi in range(n_gpus):
                        acc_ft, acc_t1c, l, n_pos_sum = single_gpu_fn(nb + gi)
                        acc_pi.append([acc_ft, acc_t1c])
                        loss_pi.append(l)

                    print('epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' %
                          (ei + 1, pi + 1, nb + 1, n_batches,
                           n_pos_sum[1] / float(np.sum(n_pos_sum)),
                           l, acc_ft, acc_t1c))

                print('patient loss: %.4f, patient acc: %.4f' %
                      (np.mean(loss_pi), np.mean(acc_pi)))

            saver.save(sess,
                       'chkpts/' + str(save_point) + '/' + SAVE_PATH + '.ckpt',
                       global_step=ei)
            print('model saved')
            # NOTE(review): this only rebinds the local name `lr`.  The
            # optimizer was built from the original variable, so the learning
            # rate it uses is not changed by this line.  Confirm the intended
            # schedule before wiring it in (e.g. through an assign op).
            lr = tf.train.exponential_decay(lr, ei, 1, 0.25, staircase=True)