Example #1
def main(_):
  #############################################################################
  # Configuration                                                             #
  #############################################################################
  # Checks for Python 3.6
  if sys.version_info[0] != 3:
    raise Exception(f"ERROR: You must use Python 3.6 but you are running "
                    f"Python {sys.version_info[0]}")

  # Prints Tensorflow version
  print(f"This code was developed and tested on TensorFlow 1.7.0. "
        f"Your TensorFlow version: {tf.__version__}.")

  # Defines {FLAGS.train_dir}, maybe based on {FLAGS.experiment_name}
  if not FLAGS.experiment_name:
    raise Exception("You need to specify an --experiment_name or --train_dir.")
  FLAGS.train_dir = (FLAGS.train_dir
                     or os.path.join(EXPERIMENTS_DIR, FLAGS.experiment_name))

  # Sets GPU settings
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True

  #############################################################################
  # Train/dev split and model definition                                      #
  #############################################################################
  # Initializes model from atlas_model.py
  module = __import__("atlas_model")
  model_class = getattr(module, FLAGS.model_name)
  atlas_model = model_class(FLAGS)

  if FLAGS.mode == "train":
    if not os.path.exists(FLAGS.train_dir):
      os.makedirs(FLAGS.train_dir)

    # Sets logging configuration
    logging.basicConfig(filename=os.path.join(FLAGS.train_dir, "log.txt"),
                        level=logging.INFO)

    # Saves a record of flags as a .json file in {train_dir}
    # TODO: read the existing flags.json file
    with open(os.path.join(FLAGS.train_dir, "flags.json"), "w") as fout:
      flags = {k: v.serialize() for k, v in FLAGS.__flags.items()}
      json.dump(flags, fout)

    with tf.Session(config=config) as sess:
      # Loads the most recent model
      print('Initializing model')
      initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=False,
                       inputDims=[FLAGS.batch_size, FLAGS.slice_height,
                                  FLAGS.slice_width, FLAGS.scan_depth])

      # print('Initializing uninitialized')
      # def initialize_uninitialized(sess):
      #   global_vars          = tf.global_variables()
      #   is_not_initialized   = sess.run([tf.is_variable_initialized(var) for var in global_vars])
      #   not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

      #   print([str(i.name) for i in not_initialized_vars]) # only for testing
      #   if len(not_initialized_vars):
      #       sess.run(tf.variables_initializer(not_initialized_vars))
      # initialize_uninitialized(sess)

      print('Running train')
      # Trains the model
      atlas_model.train(sess, *setup_train_dev_split(FLAGS))
  elif FLAGS.mode == "eval":
    with tf.Session(config=config) as sess:
      # Sets logging configuration
      logging.basicConfig(level=logging.INFO)

      # Loads the most recent model
      initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=True,
                       inputDims=[FLAGS.batch_size, FLAGS.slice_height,
                                  FLAGS.slice_width, FLAGS.scan_depth])

      # Shows examples from the dev set
      _, _, dev_input_paths, dev_target_mask_paths =\
        setup_train_dev_split(FLAGS)
      dev_dice = atlas_model.calculate_dice_coefficient(sess,
                                                        dev_input_paths,
                                                        dev_target_mask_paths,
                                                        "dev",
                                                        num_samples=FLAGS.dev_num_samples,
                                                        plot=True)
      logging.info(f"dev dice_coefficient: {dev_dice}")
Example #2
File: main.py  Project: wroderick/atlas
def main(_):
    #############################################################################
    # Configuration                                                             #
    #############################################################################
    # Checks for Python 3.6
    if sys.version_info[0] != 3:
        raise Exception(f"ERROR: You must use Python 3.6 but you are running "
                        f"Python {sys.version_info[0]}")

    # Prints Tensorflow version
    print(f"This code was developed and tested on TensorFlow 1.7.0. "
          f"Your TensorFlow version: {tf.__version__}.")

    # Defines {FLAGS.train_dir}, maybe based on {FLAGS.experiment_name}
    if not FLAGS.experiment_name:
        raise Exception(
            "You need to specify an --experiment_name or --train_dir.")
    FLAGS.train_dir = (FLAGS.train_dir
                       or os.path.join(EXPERIMENTS_DIR, FLAGS.experiment_name))

    # Sets GPU settings
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #############################################################################
    # Train/dev split and model definition                                      #
    #############################################################################
    # Initializes model from atlas_model.py
    module = __import__("atlas_model")
    model_class = getattr(module, FLAGS.model_name)
    atlas_model = model_class(FLAGS)

    if FLAGS.mode == "train":
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)

        # Sets logging configuration
        logging.basicConfig(filename=os.path.join(FLAGS.train_dir, "log.txt"),
                            level=logging.INFO)

        # Saves a record of flags as a .json file in {train_dir}
        # TODO: read the existing flags.json file
        with open(os.path.join(FLAGS.train_dir, "flags.json"), "w") as fout:
            flags = {k: v.serialize() for k, v in FLAGS.__flags.items()}
            json.dump(flags, fout)

        with tf.Session(config=config) as sess:
            # Loads the most recent model
            initialize_model(sess,
                             atlas_model,
                             FLAGS.train_dir,
                             expect_exists=False)
            # Trains the model
            atlas_model.train(sess, *setup_train_dev_split(FLAGS))
    elif FLAGS.mode == "eval":  #to use different eval filepath, python main.py --experiment_name=0015 --eval_filepath="data_output_masks0015" --mode=eval
        with tf.Session(config=config) as sess:
            # Sets logging configuration
            logging.basicConfig(level=logging.INFO)

            # Loads the most recent model
            initialize_model(sess,
                             atlas_model,
                             FLAGS.train_dir,
                             expect_exists=True)

            # Shows examples from the dev set
            _, _, dev_input_paths, dev_target_mask_paths =\
              setup_train_dev_split(FLAGS)

            # Rewrites dev input paths to point at FLAGS.eval_filepath (the default is 'data')
            for i in range(len(dev_input_paths)):
                filepath = dev_input_paths[i][0]
                dev_input_paths[i][0] = filepath.replace(
                    'data', FLAGS.eval_filepath)


#      dev_dice = atlas_model.calculate_dice_coefficient(sess,
#                                                        dev_input_paths,
#                                                        dev_target_mask_paths,
#                                                        "dev",
#                                                        num_samples=1000,
#                                                        plot=True)
#      logging.info(f"dev dice_coefficient: {dev_dice}")

            dev_dice, dev_recall_pix, dev_precision_pix, dev_recall_img, dev_precision_img = atlas_model.calculate_acc_metrics(
                sess,
                dev_input_paths,
                dev_target_mask_paths,
                "dev",
                num_samples=1000,
                plot=False)
            logging.info(f"dev dice_coefficient: {dev_dice}")
            logging.info(f"dev recall_pix: {dev_recall_pix}")
            logging.info(f"dev precision_pix: {dev_precision_pix}")
            logging.info(f"dev recall_img: {dev_recall_img}")
            logging.info(f"dev precision_img: {dev_precision_img}")

    elif FLAGS.mode == "save_output_masks":  #run with this line: python main.py --experiment_name=0002 --mode=save_output_masks --num_epochs=3 --eval_every=100 --print_every=1 --save_every=100 --summary_every=20 --model_name=ATLASModel

        with tf.Session(config=config) as sess:
            # Sets logging configuration
            logging.basicConfig(level=logging.INFO)

            # Loads the most recent model
            initialize_model(sess,
                             atlas_model,
                             FLAGS.train_dir,
                             expect_exists=True)

            # Creates a new dataset of saved output masks

            # For each image in the dataset
            # Perform a forward pass and store the resulting mask
            # Use a boolean mask to form the final output
            # Save the final output

            prefix = os.path.join(FLAGS.data_dir, "ATLAS_R1.1")
            new_prefix = os.path.join(FLAGS.output_data_dir, "ATLAS_R1.1")
            if FLAGS.input_regex is None:
                input_paths_regex = "Site*/**/*_t1w_deface_stx/*.jpg"
            else:
                input_paths_regex = FLAGS.input_regex

            slice_paths = glob.glob(os.path.join(prefix, input_paths_regex),
                                    recursive=True)
            #iter = 0
            if FLAGS.dilation:
                struct1 = np.ones((4, 4))

            for curr_file_path in slice_paths:
                curr_img = io.imread(curr_file_path)
                # opens input, resizes it, converts to a numpy array
                curr_input = Image.open(curr_file_path).convert("L")
                curr_shape = (FLAGS.slice_height, FLAGS.slice_width)
                curr_input = curr_input.crop((0, 0) + curr_shape[::-1])
                curr_input = np.asarray(curr_input) / 255.0
                curr_input = np.expand_dims(curr_input, 0)
                predicted_mask = atlas_model.get_predicted_masks_for_training_example(
                    sess, curr_input)
                output_masked_image = curr_input * predicted_mask
                output_masked_image = np.squeeze(output_masked_image)

                if FLAGS.dilation:
                    #dilate the mask (image is 0s and 1s)
                    dilated_image = ndimage.binary_dilation(
                        output_masked_image,
                        structure=struct1,
                    ).astype(output_masked_image.dtype)
                    #mask the original image with the dilated masks
                    dilated_image = curr_input[0, :, :] * dilated_image
                    #output_masked_image = np.dstack((output_masked_image,output_masked_image,output_masked_image))
                    output_masked_image = np.dstack(
                        (dilated_image, dilated_image, dilated_image))
                elif FLAGS.gaussian_filter:
                    #apply gaussian filter to image
                    gauss_filt_image = ndimage.filters.gaussian_filter(
                        output_masked_image, sigma=1)
                    output_masked_image = np.dstack(
                        (gauss_filt_image, gauss_filt_image, gauss_filt_image))
                else:
                    output_masked_image = np.dstack(
                        (output_masked_image, output_masked_image,
                         output_masked_image))
                #create new filepath to output masked images
                old_folderpath = os.path.split(curr_file_path)[0]
                filename = os.path.split(curr_file_path)[1]
                new_slice_path = old_folderpath.replace(
                    FLAGS.data_dir, FLAGS.output_data_dir)
                #if folder doesn't exist, make it in specified folder
                if not os.path.exists(new_slice_path):
                    os.makedirs(new_slice_path)
                #save the image
                io.imsave(new_slice_path + '/' + filename,
                          output_masked_image,
                          quality=100)
                print("Finished saving file: " + new_slice_path + '/' +
                      filename)
                #iter += 1
                #outpath = "../data_output_masks/"
                #io.imsave(outpath + str(iter) + '.jpg',output_masked_image,quality=100)

    elif FLAGS.mode == "saliency_map":  #run with this line: python main.py --experiment_name=0002 --mode=saliency_map --model_name=ATLASModel --example_num=2
        # examples that work well: 2, 20
        with tf.Session(config=config) as sess:
            # Sets logging configuration
            logging.basicConfig(level=logging.INFO)

            # Loads the most recent model
            initialize_model(sess,
                             atlas_model,
                             FLAGS.train_dir,
                             expect_exists=True)

            # Gets the image from the dev set
            _, _, dev_input_paths, dev_target_mask_paths =\
                      setup_train_dev_split(FLAGS)
            curr_dev_input_path = dev_input_paths[FLAGS.example_num][0]
            curr_dev_target_mask_path = dev_target_mask_paths[
                FLAGS.example_num][0][0]

            # opens input, resizes it, converts to a numpy array
            curr_input = Image.open(curr_dev_input_path).convert("L")
            curr_shape = (FLAGS.slice_height, FLAGS.slice_width)
            curr_input = curr_input.crop((0, 0) + curr_shape[::-1])
            curr_input = np.asarray(curr_input) / 255.0
            curr_input_img = np.dstack((curr_input, curr_input, curr_input))
            curr_input = np.expand_dims(curr_input, 0)

            # opens target, resizes it, converts to a numpy array
            curr_target = Image.open(curr_dev_target_mask_path).convert("L")
            curr_target_shape = (FLAGS.slice_height, FLAGS.slice_width)
            curr_target = curr_target.crop((0, 0) + curr_target_shape[::-1])
            curr_target = np.asarray(curr_target) / 255.0
            curr_target_img = np.dstack(
                (curr_target, curr_target, curr_target))
            curr_target = np.expand_dims(curr_target, 0)

            # gets the predicted mask
            predicted_mask = atlas_model.get_predicted_masks_for_training_example(
                sess, curr_input)
            predicted_mask_img = np.squeeze(predicted_mask)
            predicted_mask_img = predicted_mask_img.astype(float)
            predicted_mask_img = np.dstack(
                (predicted_mask_img, predicted_mask_img, predicted_mask_img))

            # Finds the gradients with respect to the input
            grads_wrt_input = atlas_model.get_grads_wrt_input(
                sess, curr_input, curr_target)
            grads_wrt_input = np.squeeze(grads_wrt_input)
            grads_wrt_input = np.absolute(grads_wrt_input)
            grads_wrt_input = np.power(grads_wrt_input, 1. / 3.)
            grads_wrt_input = grads_wrt_input / np.amax(grads_wrt_input)
            output_grads_wrt_input_image = np.dstack(
                (grads_wrt_input, grads_wrt_input, grads_wrt_input))

            # Plot
            plt.subplot(1, 3, 1)
            plt.imshow(curr_input_img)
            #plt.axis('off')
            plt.subplot(1, 3, 2)
            #plt.imshow(curr_input_img)
            #curr_img_mask_overlay = np.zeros(curr_target_img.shape)
            curr_img_mask_overlay = np.copy(curr_input_img)
            curr_target_img_flags = curr_target_img[:, :, 0] > 0.5
            predicted_mask_img_flags = predicted_mask_img[:, :, 0] > 0.5
            curr_img_mask_overlay[curr_target_img_flags] = 0.
            curr_img_mask_overlay[predicted_mask_img_flags] = 0.
            curr_img_mask_overlay[curr_target_img_flags, 1] = 155. / 255.
            curr_img_mask_overlay[curr_target_img_flags, 2] = 218. / 255.
            #curr_img_mask_overlay[predicted_mask_img_flags,1] = 1.
            curr_img_mask_overlay[predicted_mask_img_flags, 0] = 1.
            #curr_img_mask_overlay[:,:,0] = curr_target_img_array
            #curr_img_mask_overlay[:,:,1] = predicted_mask_img_array
            #plt.imshow(curr_img_mask_overlay, alpha=0.2)
            plt.imshow(curr_img_mask_overlay)
            #plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(output_grads_wrt_input_image)
            #plt.axis('off')
            #plt.savefig("../plots/SaliencyMap.pdf",transparent=True, bbox_inches='tight',dpi=3000)
            plt.show()

            plt.subplot(1, 3, 1)
            plt.imshow(curr_input_img)
            plt.axis('off')
            plt.subplot(1, 3, 2)
            #plt.imshow(curr_input_img)
            #curr_img_mask_overlay = np.zeros(curr_target_img.shape)
            #curr_img_mask_overlay[:,:,0] = curr_target_img[:,:,1]
            #curr_img_mask_overlay[:,:,1] = predicted_mask_img[:,:,1]
            #plt.imshow(curr_img_mask_overlay, alpha=0.2)
            plt.imshow(curr_img_mask_overlay)
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(output_grads_wrt_input_image)
            plt.axis('off')
            #plt.savefig("../plots/SaliencyMap.pdf",transparent=True, bbox_inches='tight',dpi=3000)
            plt.subplots_adjust(left=0.05,
                                bottom=0.05,
                                right=0.95,
                                top=0.95,
                                wspace=0.01,
                                hspace=0.01)
            plt.show()
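A note on the saliency computation: atlas_model.get_grads_wrt_input is called above but not shown in these snippets. In TensorFlow 1.x such a helper is typically a tf.gradients call of the loss with respect to the input placeholder, evaluated for one example; the sketch below only illustrates that pattern, and the attribute names (input_op, target_mask_op, loss, grads_wrt_input_op) are assumptions, not the project's actual API.

# A method on the model class (hypothetical attribute names throughout).
# Built once at graph-construction time:
#   self.grads_wrt_input_op = tf.gradients(self.loss, self.input_op)[0]
def get_grads_wrt_input(self, sess, input_batch, target_mask_batch):
    # Returns d(loss)/d(input) for one batch; a sketch, not the project's code.
    feed_dict = {
        self.input_op: input_batch,
        self.target_mask_op: target_mask_batch,
    }
    return sess.run(self.grads_wrt_input_op, feed_dict=feed_dict)

The absolute value, cube root, and max-normalization applied to these gradients in the snippet above compress the dynamic range so that faint gradients remain visible when plotted.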
Example #3
def main(_):
  #############################################################################
  # Configuration                                                             #
  #############################################################################
  # Checks for Python 3.6
  if sys.version_info[0] != 3:
    raise Exception(f"ERROR: You must use Python 3.6 but you are running "
                    f"Python {sys.version_info[0]}")

  # Prints Tensorflow version
  print(f"This code was developed and tested on TensorFlow 1.7.0. "
        f"Your TensorFlow version: {tf.__version__}.")

  # Defines {FLAGS.train_dir}, maybe based on {FLAGS.experiment_name}
  if not FLAGS.experiment_name:
    raise Exception("You need to specify an --experiment_name or --train_dir.")
  FLAGS.train_dir = (FLAGS.train_dir
                     or os.path.join(EXPERIMENTS_DIR, FLAGS.experiment_name))

  # Sets GPU settings
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True  

  #############################################################################
  # Train/dev split and model definition                                      #
  #############################################################################
  # Initializes model from atlas_model.py
  module = __import__("atlas_model")
  model_class = getattr(module, FLAGS.model_name)
  atlas_model = model_class(FLAGS)

  if FLAGS.mode == "train":
    if not os.path.exists(FLAGS.train_dir):
      os.makedirs(FLAGS.train_dir)

    # Sets logging configuration
    logging.basicConfig(filename=os.path.join(FLAGS.train_dir, "log.txt"),
                        level=logging.INFO)

    # Saves a record of flags as a .json file in {train_dir}
    # TODO: read the existing flags.json file
    with open(os.path.join(FLAGS.train_dir, "flags.json"), "w") as fout:
      flags = {k: v.serialize() for k, v in FLAGS.__flags.items()}
      json.dump(flags, fout)

    with tf.Session(config=config) as sess:
      # Loads the most recent model, or initializes a new one
      initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=False)

      # Trains the model
      atlas_model.train(sess, *setup_train_dev_split(FLAGS))
  elif FLAGS.mode == "eval":
    with tf.Session(config=config) as sess:
      # Sets logging configuration
      logging.basicConfig(level=logging.INFO)

      # Loads the most recent model
      initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=True)

      # Shows examples from the dev set
      _, _, dev_input_paths, dev_target_mask_paths =\
        setup_train_dev_split(FLAGS)

      dev_dice = atlas_model.calculate_dice_coefficient(sess,
                                                        dev_input_paths,
                                                        dev_target_mask_paths,
                                                        "dev",
                                                        num_samples=FLAGS.num_samples,
                                                        plot=True)
      logging.info(f"dev dice_coefficient: {dev_dice}")
  elif FLAGS.mode == "print":
    with tf.Session(config=config) as sess:
      # Sets logging configuration
      logging.basicConfig(level=logging.INFO)

      # Loads the most recent model
      initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=True)
      trained_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
      var_count = 0

      for var in trained_vars:
        # Number of scalar parameters in this variable
        size = 1
        for i in range(len(var.shape)):
          size = size * int(var.shape[i])
        var_count += size
        if 'W' in var.name:
          # Prints the mean value of each weight tensor
          print("-- Name:", var.name, "-- Value:", 1 / size * tf.reduce_sum(var).eval())

      print("Total trainable params =", var_count)
Example #4
File: main.py  Project: davidiot/atlas
def main(_):
    #############################################################################
    # Configuration                                                             #
    #############################################################################
    # Checks for Python 3.6
    if sys.version_info[0] != 3:
        raise Exception(f"ERROR: You must use Python 3.6 but you are running "
                        f"Python {sys.version_info[0]}")

    # Prints Tensorflow version
    print(f"This code was developed and tested on TensorFlow 1.7.0. "
          f"Your TensorFlow version: {tf.__version__}.")

    # Defines {FLAGS.train_dir}, maybe based on {FLAGS.experiment_name}
    if FLAGS.mode != 'box' and not FLAGS.experiment_name:
        raise Exception("You need to specify an --experiment_name or --train_dir.")
    FLAGS.train_dir = (FLAGS.train_dir
                       or os.path.join(EXPERIMENTS_DIR, FLAGS.experiment_name))

    # Sets GPU settings
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #############################################################################
    # Train/dev split and model definition                                      #
    #############################################################################
    # Initializes model from atlas_model.py
    module = __import__("atlas_model")
    model_class = getattr(module, FLAGS.model_name)
    atlas_model = model_class(FLAGS)

    if FLAGS.mode == "split":
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        setup_train_dev_split(FLAGS)

    if FLAGS.mode == "train":
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)

        # Sets logging configuration
        logging.basicConfig(filename=os.path.join(FLAGS.train_dir, "log.txt"),
                            level=logging.INFO)

        # Saves a record of flags as a .json file in {train_dir}
        # TODO: read the existing flags.json file
        with open(os.path.join(FLAGS.train_dir, "flags.json"), "w") as fout:
            flags = {k: v.serialize() for k, v in FLAGS.__flags.items()}
            json.dump(flags, fout)

        with tf.Session(config=config) as sess:
            # Loads the most recent model
            initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=False)

            # Trains the model
            atlas_model.train(sess, *setup_train_dev_split(FLAGS))
    elif FLAGS.mode == "eval":
        with tf.Session(config=config) as sess:
            # Sets logging configuration
            logging.basicConfig(level=logging.INFO)

            # Loads the most recent model
            initialize_model(sess, atlas_model, FLAGS.train_dir, expect_exists=True)

            # Shows examples from the dev set
            _, _, dev_input_paths, dev_target_mask_paths = \
                setup_train_dev_split(FLAGS)
            dev_dice = atlas_model.calculate_dice_coefficient(sess,
                                                              dev_input_paths,
                                                              dev_target_mask_paths,
                                                              "dev",
                                                              num_samples=1000,
                                                              plot=True)
            logging.info(f"dev dice_coefficient: {dev_dice}")
    elif FLAGS.mode == "box":
        bb.generate_boxes(FLAGS)
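A note on how these entry points are invoked: all four variants define main(_) and read options from a module-level FLAGS object. In TensorFlow 1.x this is conventionally set up with tf.app.flags and dispatched via tf.app.run(); the sketch below shows only that boilerplate, with a handful of the flags used above and purely illustrative defaults.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("experiment_name", "", "Name of the experiment directory.")
tf.app.flags.DEFINE_string("train_dir", "", "Training directory; derived from experiment_name if empty.")
tf.app.flags.DEFINE_string("mode", "train", "One of train, eval, save_output_masks, saliency_map, print, split, box.")
tf.app.flags.DEFINE_string("model_name", "ATLASModel", "Class in atlas_model.py to instantiate.")
tf.app.flags.DEFINE_integer("batch_size", 100, "Batch size (illustrative default).")

if __name__ == "__main__":
    tf.app.run()  # parses command-line flags, then calls main(argv)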