Example #1
 def __init__(self, mode, result_dir, train_dir, testset_dir,
              min_num_workers, cfg):
     """
         Args:
             result_dir: Set to None for non-training modes.
     """
     super(PneuSegDF, self).__init__()
     self.cfg = cfg
     self.min_num_workers = min_num_workers
     self.data_dirs = finalize_data_dirs(mode, result_dir, train_dir,
                                         testset_dir, cfg)
     self.ex_process = extra_processing(cfg.im_size, cfg.num_class,
                                        cfg.preprocess, cfg.loss)
     print(f"\nNumber of samples: {len(self)}\n")
Example #2
def training_create_ds(args, cfg, num_threads, result_dir):
    data_dirs, val_dirs = finalize_data_dirs(args.mode, result_dir,
                                             args.train_dir, args.testset_dir,
                                             cfg)
    print(f"Num of training samples: {len(data_dirs)}")
    print(f"Num of validation samples: {len(val_dirs)}")
    output_types = (tf.float32, tf.float32, tf.int64, tf.string, tf.float32)
    output_shapes = (
        tf.TensorShape((None, cfg.im_size[1], cfg.im_size[0], 1)),
        tf.TensorShape((None, cfg.im_size[1], cfg.im_size[0], 1)),
        tf.TensorShape((None, 2)),
        tf.TensorShape(None),
        # No TF ops are applied to the original images, so their shape info doesn't
        # matter and an unknown shape is fine here. The same works for any other
        # output whose shape info we don't care about.
        tf.TensorShape(None))
    debug = False
    iterator = tf.data.Iterator.from_structure(output_types, output_shapes)
    ex_prc = extra_processing(cfg.im_size, cfg.num_class, cfg.preprocess,
                              cfg.loss)

    # num_gpus is assumed to be defined at module level (not shown in this snippet).
    train_steps_per_epoch = len(data_dirs) // (num_gpus * cfg.batch_size) + 1
    train_ds = tf.data.Dataset.from_generator(lambda: (d for d in data_dirs),
                                              output_types=tf.string)
    train_ds = train_ds.map(
        lambda data_dir: tf.py_func(ex_prc.train_process_tfdata, [data_dir],
                                    output_types), num_threads)
    train_ds = train_ds.repeat().batch(cfg.batch_size).prefetch(512)
    train_init = iterator.make_initializer(train_ds)

    val_num_batch = len(val_dirs) // (num_gpus * cfg.batch_size) + 1
    val_ds = tf.data.Dataset.from_generator(lambda: (d for d in val_dirs),
                                            output_types=tf.string)
    val_ds = val_ds.map(
        lambda data_dir: tf.py_func(ex_prc.val_process_tfdata, [data_dir],
                                    output_types), num_threads)
    val_ds = val_ds.batch(cfg.batch_size).prefetch(512)
    val_init = iterator.make_initializer(val_ds)

    ds_handles = dict(next_ele=iterator.get_next(),
                      train_init=train_init,
                      val_init=val_init,
                      train_steps_per_epoch=train_steps_per_epoch,
                      val_num_batch=val_num_batch)
    return ds_handles
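Downstream, the returned handles would typically drive a TF1 reinitializable-iterator loop. A minimal consumption sketch, assuming a live tf.Session named sess and that the five tensors yielded by next_ele follow output_types above (the unpacked names are assumptions):

handles = training_create_ds(args, cfg, num_threads=4, result_dir=result_dir)
for epoch in range(cfg.max_epoch):
    # Point the shared iterator at the (repeating) training dataset.
    sess.run(handles["train_init"])
    for _ in range(handles["train_steps_per_epoch"]):
        dist_im, gt, meta, path, og_im = sess.run(handles["next_ele"])
        # ... feed the batch to the training op here ...
    # Point the same iterator at the (non-repeating) validation dataset.
    sess.run(handles["val_init"])
    try:
        for _ in range(handles["val_num_batch"]):
            val_batch = sess.run(handles["next_ele"])
            # ... run validation here ...
    except tf.errors.OutOfRangeError:
        pass  # validation dataset exhausted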
Example #3
 def __init__(self, cfg):
     self.ex_process = extra_processing(cfg.im_size, cfg.num_class,
                                        cfg.preprocess, cfg.loss)
Example #4
def evaluation(mode, sess, args, cfg, model=None, pkl_dir=None, log=False):
    """
        Args:
            mode: ["during_training", "eval", "eval_mutli"]
            model: The difference between the modes is as follow.
                For eval during training (the training model will be passed). 
                You need to create one at the very beginning of the eval 
                in "eval_multi" mode (outside of this function since it will be
                ran in a loop). 
                "Eval" mode will create a model for you. (This mode is not
                recommended)
            pkl_dir: Provided for an eval during training to overwrite the args.
    """
    info_paths = []
    for folder in args.testset_dir:
        info_paths += get_infos(folder)
    if args.eval_debug:
        random.shuffle(info_paths)
    else:
        info_paths = sorted(info_paths, key=lambda info: info[0])
    if mode == "eval":  # This mode is deprecated
        model = tf_model(cfg, False)
        saver = tf.train.Saver()
        saver.restore(sess, args.model_file)
    elif mode == "during_training":
        args.pkl_dir = pkl_dir
    elif mode == "eval_multi":
        saver = tf.train.Saver()
        saver.restore(sess, args.model_file)

    pbar = tqdm(total=len(info_paths))
    if os.path.exists(args.pkl_dir):
        input("Result file already exists. Press enter to \
            continue and overwrite it when inference is done...")
    all_result = []
    for info in info_paths:
        img_file, lab_file = info[0:2]
        try:
            img_ori = sitk.ReadImage(img_file, sitk.sitkFloat32)
            lab_ori = sitk.ReadImage(lab_file, sitk.sitkInt16)
            img_arr = sitk.GetArrayFromImage(img_ori)
            lab_arr = sitk.GetArrayFromImage(lab_ori)
        except Exception:
            # Skip samples whose image/label files cannot be read.
            continue
        depth, ori_shape = img_arr.shape[0], img_arr.shape[1:]
        pneumonia_type = Pneu_type(img_file, False)
        spacing = img_ori.GetSpacing()
        ex_processing = extra_processing(cfg.im_size,
                                         cfg.num_class,
                                         cfg.preprocess,
                                         cfg.loss,
                                         og_shape=ori_shape[::-1])
        dis_arr, lab_arr = ex_processing.batch_preprocess(
            img_arr, lab_arr, False)

        pred_ = []
        segs = cfg.batch_size
        assert isinstance(segs, int) and 0 < segs < 70, \
            "cfg.batch_size must be a positive integer smaller than 70"
        step = depth // segs + 1 if depth % segs != 0 else depth // segs
        for ii in range(step):
            if ii != step - 1:
                pp = sess.run(model.pred["seg_map"],
                              feed_dict={
                                  model.in_im:
                                  dis_arr[ii * segs:(ii + 1) * segs, ...]
                              })  #[0]
            else:
                pp = sess.run(
                    model.pred["seg_map"],
                    feed_dict={model.in_im: dis_arr[ii * segs:, ...]})  #[0]
            pp = 1 / (1 + np.exp(-pp))  # this only works for single class
            pred_.append(pp)
        dis_prd = np.concatenate(pred_, axis=0)
        # add the og version in
        if cfg.num_class == 1:
            dis_prd = dis_prd > 0.5
        else:
            if pneumonia_type == "common_pneu":
                cls_id = 2
            elif pneumonia_type == "covid_pneu":
                cls_id = 1
            else:
                raise Exception("Unknown condition!")
            dis_prd = np.argmax(dis_prd, -1) == cls_id
        dis_prd = ex_processing.batch_postprocess(dis_prd)
        if args.eval_debug:
            pred_nii = sitk.GetImageFromArray(dis_prd)
            pred_nii.CopyInformation(lab_ori)
            im_dir_list = img_file.split('/')
            debug_out = pj(os.path.dirname(args.model_file), args.eval_debug,
                           '_'.join([pneumonia_type] + im_dir_list[-5:-1]))
            if not os.path.exists(debug_out):
                os.makedirs(debug_out)
            out_niifile = pj(
                debug_out,
                os.path.basename(img_file).replace(".nii.gz", "_pred.nii.gz"))
            assert not os.path.exists(
                out_niifile)  # Otherwise existing files can be appended
            sitk.WriteImage(pred_nii, out_niifile)
            shutil.copy(img_file, debug_out)
            shutil.copy(lab_file, debug_out)
        else:
            score = dice_coef_pat(dis_prd, lab_arr)
            if score < 0.3:
                if args.viz:
                    viz_patient(img_arr, dis_prd, lab_arr,
                                pj(os.path.dirname(args.model_file), args.viz),
                                img_file)
                if log:
                    logging.info(os.path.dirname(lab_file))
                    logging.info(score)
                else:
                    print(os.path.dirname(lab_file))
                    print(score)
            all_result.append([img_file, score, round(spacing[-1], 1)])
        pbar.update(1)
    pbar.close()
    pickle.dump(all_result, open(args.pkl_dir, "bw"))
    show_dice(all_result, args.thickness_thres, log=log)
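dice_coef_pat is not shown in this snippet; below is a minimal sketch of the standard per-patient Dice coefficient it presumably computes, treating any nonzero voxel as foreground (the epsilon and exact conventions are assumptions):

import numpy as np

def dice_coefficient(pred, target, eps=1e-6):
    """Standard Dice: 2 * |intersection| / (|pred| + |target|) on binary masks."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)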
Example #5
def train(sess, args, cfg):
    train_dataset = ini_training_set(args, cfg)
    num_batches = len(train_dataset) // cfg.batch_size
    model = tf_model(cfg, False, args.gpus_to_use, num_batches)
    if args.resume:
        output_dir = os.path.dirname(args.resume)
    else:
        if os.path.exists(args.output_dir):
            output_dir = args.output_dir
        else:
            # If output_dir does not point to an existing directory,
            # create a new timestamped one automatically.
            output_dir = args.output_dir + time.strftime(
                "%m%d_%H%M_%S", time.localtime())
        if os.path.exists(output_dir):
            if not args.debug:
                input(
                    "The output directory already exists. Please wait a moment and restart..."
                )
                # print("The output directory already exists, please wait a moment and restart...")
                # sys.exit()
        else:
            os.makedirs(output_dir)
        if not os.path.exists(pj(output_dir, os.path.basename(args.config))):
            shutil.copy(args.config, output_dir)
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s %(message)s",
                        datefmt="%m-%d %H:%M",
                        filename=pj(output_dir, "training.log"),
                        filemode="a")
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter("%(message)s"))
    logging.getLogger("").addHandler(console)
    saver = tf.train.Saver(max_to_keep=args.max_to_keep)
    num_para = np.sum(
        [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
    logging.info("Total number of trainable parameters: {:.3}M.\n".format(
        float(num_para) / 1e6))
    if args.resume:
        # So that an extended model can use pretrained weights?
        # Add another arg for this? Does this affect the (free) model checking otherwise?
        # sess.run(tf.global_variables_initializer())
        saver.restore(sess, args.resume)
        # This needs to be changed if the naming rule changes
        epoch = int(
            (os.path.basename(args.resume).split('.')[0]).split('_')[-1])
    else:
        sess.run(tf.global_variables_initializer())
        epoch = 0
    while epoch < cfg.max_epoch:
        logging.info(f"Epoch {epoch + 1}\n")
        num_batches = 10 if args.debug else num_batches
        # with tf.contrib.tfprof.ProfileContext("", trace_steps=[], dump_steps=[]) as pctx:
        for i in range(num_batches):
            data_list = []
            for j in range(i * cfg.batch_size, (i + 1) * cfg.batch_size):
                # Be careful when the batch size can't be divided evenly by the
                # number of GPUs, especially for the last batch in multi-GPU training.
                ex_process = extra_processing(cfg.im_size, cfg.num_class,
                                              cfg.preprocess, cfg.loss)
                im_ar, ann_ar = ex_process.preprocess(train_dataset[j][0],
                                                      train_dataset[j][1],
                                                      True)
                data_list.append((im_ar, ann_ar))
            data_ar = np.array(data_list)
            ret_loss, ret_lr, _, global_step = sess.run(
                [
                    model.loss, model.learning_rate, model.opt_op,
                    model.global_step
                ],
                feed_dict={
                    model.in_im: data_ar[:, 0, :, :, :],
                    model.in_gt: data_ar[:, 1, :, :, :],
                })
            # if i % 5 == 0:
            logging.info(
                f"Epoch progress: {i + 1} / {num_batches}, loss: {ret_loss:.5f}, lr: {ret_lr:.8f}, global step: {global_step}"
            )
        for _ in range(args.num_retry):
            try:
                ckpt_dir = pj(output_dir, f"epoch_{epoch + 1}.ckpt")
                saver.save(sess, ckpt_dir)
                break
            except Exception:
                logging.warning(
                    f"Failed to save checkpoint. Retrying in {args.retry_waittime} seconds...")
                time.sleep(args.retry_waittime)
        if args.eval_while_train and not args.gpus_to_use:
            evaluation("during_training",
                       sess,
                       args,
                       cfg,
                       model,
                       ckpt_dir.replace(".ckpt", "_res.pkl"),
                       log=True)
        epoch += 1
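A hypothetical entry point for the training loop above; parse_args, load_config, and the session options are illustrative assumptions, not taken from the source:

if __name__ == "__main__":
    args = parse_args()             # assumed CLI argument parser for this project
    cfg = load_config(args.config)  # assumed config loader
    sess_cfg = tf.ConfigProto()
    sess_cfg.gpu_options.allow_growth = True  # don't reserve all GPU memory up front
    with tf.Session(config=sess_cfg) as sess:
        train(sess, args, cfg)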