Code example #1
def train():
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']

    files = []
    num_labels = 5
    with open('train.txt') as f:
        for line in f:
            files.append(line[:-1])
    print("%d training samples" % len(files))

    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 5))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':

        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')

    elif model_name == 'dense24':

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)

    flair_t2_score = flair_t2_15[:, 13:25, 13:25, 13:25, :] + \
                     flair_t2_27[:, 13:25, 13:25, 13:25, :]

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    loss = segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) + \
           segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5)

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=5e-4).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200,
                                   correction=options['correction'])

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())
        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                data, labels, centers = next(data_gen_train)
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))
                for nb in range(n_batches):
                    offset_batch = min(nb * BATCH_SIZE, centers.shape[1] - BATCH_SIZE)
                    data_batch, label_batch = get_patches_3d(data, labels, centers[:, offset_batch:offset_batch + BATCH_SIZE], HSIZE, WSIZE, CSIZE, PSIZE, False)
                    label_batch = label_transform(label_batch, 5)
                    _, l, acc_ft, acc_t1c = sess.run(fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                                                   feed_dict={flair_t2_node: data_batch[:, :, :, :, :2],
                                                              t1_t1ce_node: data_batch[:, :, :, :, 2:],
                                                              flair_t2_gt_node: label_batch[0],
                                                              t1_t1ce_gt_node: label_batch[1],
                                                              learning_phase(): 1})
                    acc_pi.append([acc_ft, acc_t1c])
                    loss_pi.append(l)
                    n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)
                    print('epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' % \
                          (ei + 1, pi + 1, nb + 1, n_batches, n_pos_sum[1]/float(np.sum(n_pos_sum)), l, acc_ft, acc_t1c))

                print('patient loss: %.4f, patient acc: %.4f' % (np.mean(loss_pi), np.mean(acc_pi)))

            saver.save(sess, SAVE_PATH, global_step=ei)
            print('model saved')

if __name__ == '__main__':
    train()
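
The training loop above depends on several helpers that do not appear in this listing: label_transform (which turns the integer label patches into the two one-hot targets fed to flair_t2_gt_node and t1_t1ce_gt_node), segmentation_loss (a per-voxel classification loss over the cropped score maps), and acc_tf (a voxel accuracy metric). A minimal sketch of plausible implementations follows; it assumes one-hot integer targets and softmax cross-entropy, and the project's actual code may differ.

# Sketches of the unshown helpers, under the assumptions stated above.
import numpy as np
import tensorflow as tf


def segmentation_loss(gt, score, num_classes):
    # mean softmax cross-entropy between one-hot voxel labels and raw scores
    gt_flat = tf.reshape(tf.cast(gt, tf.float32), (-1, num_classes))
    score_flat = tf.reshape(score, (-1, num_classes))
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=gt_flat, logits=score_flat))


def acc_tf(y_pred, y_true):
    # percentage of voxels whose predicted class matches the one-hot ground truth
    correct = tf.equal(tf.argmax(y_pred, axis=-1), tf.argmax(y_true, axis=-1))
    return 100.0 * tf.reduce_mean(tf.cast(correct, tf.float32))


def label_transform(label_batch, num_classes):
    # returns (binary tumour-vs-background one-hot, full num_classes one-hot)
    whole = (label_batch > 0).astype(np.int32)
    binary = np.stack([1 - whole, whole], axis=-1)
    full = np.eye(num_classes, dtype=np.int32)[label_batch]
    return binary, full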
Code example #2
def main():
    test_files = []
    with open("test_hgg.txt") as f:
        for line in f:
            test_files.append(line[:-2])

    num_labels = 5
    OFFSET_H = options["offset_h"]
    OFFSET_W = options["offset_w"]
    OFFSET_C = options["offset_c"]
    HSIZE = options["hsize"]
    WSIZE = options["wsize"]
    CSIZE = options["csize"]
    PSIZE = options["psize"]
    SAVE_PATH = options["model_path"]
    model_name = options["model_name"]

    OFFSET_PH = (HSIZE - PSIZE) / 2
    OFFSET_PW = (WSIZE - PSIZE) / 2
    OFFSET_PC = (CSIZE - PSIZE) / 2

    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(
        dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2)
    )
    t1_t1ce_node = tf.placeholder(
        dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2)
    )

    if model_name == "dense48":

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=flair_t2_node, name="flair"
        )
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=t1_t1ce_node, name="t1"
        )
    elif model_name == "no_dense":

        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(
            input=flair_t2_node, name="flair"
        )
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(
            input=t1_t1ce_node, name="t1"
        )

    elif model_name == "dense24":

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name="flair"
        )
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name="t1"
        )

    elif model_name == "dense24_nocorrection":

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name="flair"
        )
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name="t1"
        )

    else:
        raise ValueError("No such model name: %s" % model_name)

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(
        num_labels,
        kernel_size=1,
        strides=1,
        padding="same",
        name="t1_t1ce_15_cls",
    )(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(
        num_labels,
        kernel_size=1,
        strides=1,
        padding="same",
        name="t1_t1ce_27_cls",
    )(t1_t1ce_27)

    t1_t1ce_score = (
        t1_t1ce_15[:, 13:25, 13:25, 13:25, :]
        + t1_t1ce_27[:, 13:25, 13:25, 13:25, :]
    )

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []
    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)
        for i in range(len(test_files)):
            print("predicting %s" % test_files[i])
            x, x_n, y = next(data_gen_test)
            pred = np.zeros([240, 240, 155, 5])
            for hi in range(batches_h):
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = int(offset_h + OFFSET_PH)
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = int(offset_w + OFFSET_PW)
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = int(offset_c + OFFSET_PC)
                        data = x[
                            offset_h : offset_h + HSIZE,
                            offset_w : offset_w + WSIZE,
                            offset_c : offset_c + CSIZE,
                            :,
                        ]
                        data_norm = x_n[
                            offset_h : offset_h + HSIZE,
                            offset_w : offset_w + WSIZE,
                            offset_c : offset_c + CSIZE,
                            :,
                        ]
                        data_norm = np.expand_dims(data_norm, 0)
                        # run inference only on patches that are not entirely background
                        if not (np.max(data) == 0 and np.min(data) == 0):
                            score = sess.run(
                                fetches=t1_t1ce_score,
                                feed_dict={
                                    flair_t2_node: data_norm[:, :, :, :, :2],
                                    t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                    learning_phase(): 0,
                                },
                            )
                            pred[
                                offset_ph : offset_ph + PSIZE,
                                offset_pw : offset_pw + PSIZE,
                                offset_pc : offset_pc + PSIZE,
                                :,
                            ] += np.squeeze(score)
            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)
            print("calculating dice...")
            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)
            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            try:
                dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            except ValueError:
                print("Skipped.")
                continue
            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print("mean dice whole:")
        print(np.mean(dice_whole, axis=0))
        print("mean dice core:")
        print(np.mean(dice_core, axis=0))
        print("mean dice enhance:")
        print(np.mean(dice_et, axis=0))

        np.save(model_name + "_dice_whole", dice_whole)
        np.save(model_name + "_dice_core", dice_core)
        np.save(model_name + "_dice_enhance", dice_et)
        print("pred saved")
Code example #3
File: test.py  Project: lelechen63/spie
def main():
    test_files = []
    with open('test.txt') as f:
        for line in f:
            test_files.append(line[:-1])

    num_labels = 5
    OFFSET_H = options['offset_h']
    OFFSET_W = options['offset_w']
    OFFSET_C = options['offset_c']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    PSIZE = options['psize']
    SAVE_PATH = options['model_path']
    model_name = options['model_name']

    OFFSET_PH = (HSIZE - PSIZE) // 2
    OFFSET_PW = (WSIZE - PSIZE) // 2
    OFFSET_PC = (CSIZE - PSIZE) // 2

    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))

    if model_name == 'dense48':

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':

        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node,
                                                            name='t1')

    elif model_name == 'dense24':

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')

    elif model_name == 'dense24_nocorrection':

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')

    else:
        raise ValueError('No such model name: %s' % model_name)

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(num_labels,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []
    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)
        for i in range(len(test_files) - 1, -1, -1):  # walk all test files in reverse order
            print(i)
            print('predicting %s' % test_files[i])
            x, x_n, y = gen_test_data(test_files[i])
            pred = np.zeros([240, 240, 155, 5])
            for hi in range(batches_h):
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = offset_h + OFFSET_PH
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = offset_w + OFFSET_PW
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = offset_c + OFFSET_PC
                        data = x[offset_h:offset_h + HSIZE,
                                 offset_w:offset_w + WSIZE,
                                 offset_c:offset_c + CSIZE, :]
                        data_norm = x_n[offset_h:offset_h + HSIZE,
                                        offset_w:offset_w + WSIZE,
                                        offset_c:offset_c + CSIZE, :]
                        data_norm = np.expand_dims(data_norm, 0)
                        # run inference only on patches that are not entirely background
                        if not (np.max(data) == 0 and np.min(data) == 0):
                            score = sess.run(fetches=t1_t1ce_score,
                                             feed_dict={
                                                 flair_t2_node:
                                                 data_norm[:, :, :, :, :2],
                                                 t1_t1ce_node:
                                                 data_norm[:, :, :, :, 2:],
                                                 learning_phase():
                                                 0
                                             })
                            pred[offset_ph:offset_ph + PSIZE,
                                 offset_pw:offset_pw + PSIZE,
                                 offset_pc:offset_pc +
                                 PSIZE, :] += np.squeeze(score)

            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)
            print('calculating dice...')
            print(options['save_path'] + test_files[i] + '_prediction')
            np.save(options['save_path'] + test_files[i] + '_prediction', pred)
            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)
            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print('mean dice whole:')
        print(np.mean(dice_whole, axis=0))
        print('mean dice core:')
        print(np.mean(dice_core, axis=0))
        print('mean dice enhance:')
        print(np.mean(dice_et, axis=0))

        np.save(model_name + '_dice_whole', dice_whole)
        np.save(model_name + '_dice_core', dice_core)
        np.save(model_name + '_dice_enhance', dice_et)
        print('pred saved')
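
The [:, 13:25, 13:25, 13:25, :] crop and the OFFSET_P* arithmetic in both test scripts tie the prediction window to the input window: with an input patch of HSIZE voxels and a prediction patch of PSIZE voxels, the centred crop starts at (HSIZE - PSIZE) // 2, which equals 13 when HSIZE = 38 and PSIZE = 12 (the configuration is not shown, so those values are assumptions). The standalone sketch below replays the tiling for one 240-voxel axis and checks that the central crops cover everything except a fixed margin at the volume border:

# Sliding-window coverage check for one axis, assuming HSIZE = 38, PSIZE = 12
# and OFFSET_H = 12; these values are not part of the listing above.
import numpy as np

HSIZE, PSIZE, OFFSET_H, EXTENT = 38, 12, 12, 240
OFFSET_PH = (HSIZE - PSIZE) // 2                        # 13, matching the 13:25 crop

batches_h = int(np.ceil((EXTENT - HSIZE) / float(OFFSET_H))) + 1
covered = np.zeros(EXTENT, dtype=int)
for hi in range(batches_h):
    offset_h = min(OFFSET_H * hi, EXTENT - HSIZE)       # clamp the last window to the border
    offset_ph = offset_h + OFFSET_PH                    # where the central PSIZE crop lands
    covered[offset_ph:offset_ph + PSIZE] += 1

# every voxel in [OFFSET_PH, EXTENT - OFFSET_PH) is predicted at least once;
# the OFFSET_PH-voxel margins keep the background label from np.zeros
assert covered[OFFSET_PH:EXTENT - OFFSET_PH].min() >= 1
assert covered[:OFFSET_PH].sum() == 0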
Code example #4
def train():
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']
    lr = tf.Variable(5e-4, trainable=False)

    num_labels = 5
    files = get_dataset_dirnames(options['root_path'])
    print('%d training samples' % len(files))

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32,
                                      shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32,
                                     shape=(None, PSIZE, PSIZE, PSIZE, 5))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':

        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node,
                                                            name='t1')

    elif model_name == 'dense24':

        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2,
                         kernel_size=1,
                         strides=1,
                         padding='same',
                         name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2,
                         kernel_size=1,
                         strides=1,
                         padding='same',
                         name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    flair_t2_score = flair_t2_15[:, 13:25, 13:25, 13:25, :] + \
                     flair_t2_27[:, 13:25, 13:25, 13:25, :]

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    loss = segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) + \
           segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5)

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200)

    def single_gpu_fn(nb, gpuname='/device:GPU:0', q=None):  # q - result queue
        with tf.device(gpuname):
            offset_batch = min(nb * BATCH_SIZE, centers.shape[1] - BATCH_SIZE)
            data_batch, label_batch = get_patches_3d(
                data, labels, centers[:,
                                      offset_batch:offset_batch + BATCH_SIZE],
                HSIZE, WSIZE, CSIZE, PSIZE, False)
            label_batch = label_transform(label_batch, 5)
            _, l, acc_ft, acc_t1c = sess.run(
                fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                feed_dict={
                    flair_t2_node: data_batch[:, :, :, :, :2],
                    t1_t1ce_node: data_batch[:, :, :, :, 2:],
                    flair_t2_gt_node: label_batch[0],
                    t1_t1ce_gt_node: label_batch[1],
                })
            n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)

        return acc_ft, acc_t1c, l, n_pos_sum

    if not os.path.isdir('chkpts'):
        os.mkdir('chkpts')
        os.mkdir('chkpts/0')
        save_point = 0
    else:
        save_point = sorted(
            [int(x.split('/')[-1]) for x in glob.glob('chkpts/*')])[-1] + 1
        os.mkdir('chkpts/%d' % save_point)

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())
        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                data, labels, centers = next(data_gen_train)
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))
                threads = []

                for nb in range(0, n_batches, len(options['gpu_ids'])):
                    for gi, x in enumerate(options['gpu_ids']):

                        #t = time.time()

                        acc_ft, acc_t1c, l, n_pos_sum = single_gpu_fn(nb + gi)
                        acc_pi.append([acc_ft, acc_t1c])
                        loss_pi.append(l)
                        '''
                        q = [Queue.Queue() for _ in range(4)]
                        t = Thread(target=single_gpu_fn, args=(nb+gi,'/device:GPU:%d'%x, q))
                        threads.append(t)
                    
                    for th in threads:
                        th.start()
                    for th in threads:
                        th.join()
                    threads = []
                    
                    queue_avg = lambda x, i: np.average(list(x[i].queue))
                    acc_ft, acc_t1c, l, n_pos_sum = queue_avg(q, 0), queue_avg(q, 1), queue_avg(q, 2), np.mean(list(q[3].queue), axis=0)
                    '''

                    #print ('TIME: %.4f'%(time.time()-t))

                    print('epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' %
                          (ei + 1, pi + 1, nb + 1, n_batches, n_pos_sum[1] / float(np.sum(n_pos_sum)), l, acc_ft, acc_t1c))

                print('patient loss: %.4f, patient acc: %.4f' %
                      (np.mean(loss_pi), np.mean(acc_pi)))

            saver.save(sess,
                       'chkpts/' + str(save_point) + '/' + SAVE_PATH + '.ckpt',
                       global_step=ei)
            print('model saved')

            # decay the learning rate after each epoch; assigning into the lr
            # variable ensures the optimizer built above actually sees the new value
            # (the original exponential_decay call created a tensor that was never used)
            sess.run(lr.assign(lr * 0.25))
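
Both training listings extract their input and label patches with get_patches_3d, which is also not part of this excerpt. A minimal sketch, assuming centers is a (3, N) array of voxel coordinates and that the HSIZE x WSIZE x CSIZE input patch and the PSIZE-sized label patch are both centred on the same voxel (the real helper may handle borders, augmentation, or the flip flag differently):

# Hypothetical sketch of get_patches_3d; the project's helper may differ.
# Assumes data has shape (H, W, C, n_modalities), labels has shape (H, W, C)
# and centers is a (3, N) array of (h, w, c) voxel coordinates.
import numpy as np


def get_patches_3d(data, labels, centers, hsize, wsize, csize, psize, flip=False):
    data_patches, label_patches = [], []
    for h, w, c in centers.T:
        h0, w0, c0 = h - hsize // 2, w - wsize // 2, c - csize // 2
        data_patches.append(data[h0:h0 + hsize, w0:w0 + wsize, c0:c0 + csize, :])
        ph0, pw0, pc0 = h - psize // 2, w - psize // 2, c - psize // 2
        label_patches.append(labels[ph0:ph0 + psize, pw0:pw0 + psize, pc0:pc0 + psize])
    # flip (augmentation) is accepted for signature compatibility but ignored here
    return np.asarray(data_patches), np.asarray(label_patches)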