Example 1
    def calculate_loss(self,
                       sess,
                       input_paths,
                       target_mask_paths,
                       dataset,
                       num_samples=None):
        """
        Calculates the loss for a dataset, represented by a list of {input_paths}
        and {target_mask_paths}.

        Inputs:
        - sess: A TensorFlow Session object.
        - input_paths: A list of Python strs that represent pathnames to input
          image files.
        - target_mask_paths: A list of Python strs that represent pathnames to
          target mask files.
        - dataset: A Python str that represents the dataset being tested. Options:
          {train,dev}. Just for logging purposes.
        - num_samples: A Python int that represents the number of samples to test.
          If num_samples=None, then tests the whole dataset.

        Outputs:
        - loss: A Python float that represents the average loss across the sampled
          examples.
        """
        logging.info(f"Calculating loss for {num_samples} examples from "
                     f"{dataset}...")
        tic = time.time()

        loss_per_batch, batch_sizes = [], []

        sbg = SliceBatchGenerator(
            input_paths,
            target_mask_paths,
            self.FLAGS.batch_size,
            num_samples=num_samples,
            shape=tuple([self.FLAGS.slice_height, self.FLAGS.slice_width] + (
                [self.FLAGS.scan_depth] if self.FLAGS.use_volumetric else [])),
            use_fake_target_masks=self.FLAGS.use_fake_target_masks)
        # Iterates over batches
        for batch in sbg.get_batch():
            # Gets loss for this batch
            loss = self.get_loss_for_batch(sess, batch)
            cur_batch_size = batch.batch_size
            loss_per_batch.append(loss * cur_batch_size)
            batch_sizes.append(cur_batch_size)

        # Calculates average loss
        total_num_examples = sum(batch_sizes)

        # Overall loss is total loss divided by total number of examples
        loss = sum(loss_per_batch) / float(total_num_examples)

        toc = time.time()
        logging.info(f"Calculating loss took {toc-tic} sec.")
        return loss
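
The batch-size weighting above matters because the final batch is usually smaller than FLAGS.batch_size, so a plain mean over per-batch losses would over-weight it. A minimal self-contained sketch of the same computation (the loss values and batch sizes below are made up for illustration):

    # Per-batch mean losses and batch sizes (illustration values only).
    per_batch_mean_losses = [0.42, 0.37, 0.55]
    batch_sizes = [32, 32, 7]  # the last batch is smaller

    # Scale each batch's mean loss back to a sum over its examples,
    # then divide by the total example count for a true per-example mean.
    total_loss = sum(l * n for l, n in zip(per_batch_mean_losses, batch_sizes))
    weighted_mean = total_loss / float(sum(batch_sizes))

    # A naive mean over batches would slightly over-weight the small batch.
    naive_mean = sum(per_batch_mean_losses) / len(per_batch_mean_losses)
    print(weighted_mean, naive_mean)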
Example 2
    def get_dev_loss(self,
                     sess,
                     dev_input_paths,
                     dev_target_mask_paths,
                     num_samples=None):
        """
        Gets the loss for the entire dev set.

        Inputs:
        - sess: A TensorFlow Session object.
        - dev_input_paths: A list of Python strs that represent pathnames to input
          image files in the dev set.
        - dev_target_mask_paths: A list of Python strs that represent pathnames to
          target mask files in the dev set.
        - num_samples: A Python int or None. If None, then evaluates on the entire
          dev set.

        Outputs:
        - dev_loss: A Python float that represents the average loss across the dev
          set.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()

        loss_per_batch, batch_sizes = [], []

        sbg = SliceBatchGenerator(
            dev_input_paths,
            dev_target_mask_paths,
            self.FLAGS.batch_size,
            num_samples=num_samples,
            shape=(self.FLAGS.slice_height, self.FLAGS.slice_width),
            use_fake_target_masks=self.FLAGS.use_fake_target_masks)
        # Iterates over dev set batches
        for batch in sbg.get_batch():
            # Gets loss for this batch
            loss = self.get_loss(sess, batch)
            cur_batch_size = batch.batch_size
            loss_per_batch.append(loss * cur_batch_size)
            batch_sizes.append(cur_batch_size)

        # Calculates average loss
        total_num_examples = sum(batch_sizes)

        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        toc = time.time()
        logging.info(f"Calculating dev loss took {toc-tic} sec.")
        return dev_loss
Example 3
  def get_batch_generator(self,
                          input_paths,
                          target_mask_paths,
                          num_samples=None,
                          flip_images=True):
    # NOTE: flip_images is accepted but not forwarded to SliceBatchGenerator.
    return SliceBatchGenerator(
        input_paths,
        target_mask_paths,
        self.FLAGS.batch_size,
        num_samples=num_samples,
        shape=(self.FLAGS.slice_height, self.FLAGS.slice_width),
        use_fake_target_masks=self.FLAGS.use_fake_target_masks)
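
SliceBatchGenerator itself is not shown in these examples, but its interface can be inferred from how it is used: get_batch() yields batch objects, get_num_batches() reports the batch count (used by tqdm in Example 4), and each batch exposes at least a batch_size attribute. A hypothetical minimal stand-in with just enough of that interface to exercise the loss-averaging code, useful for testing:

    import math

    class FakeSliceBatchGenerator:
      """Hypothetical stand-in for SliceBatchGenerator; only the pieces
      used by the loss-calculation path are implemented."""

      class Batch:
        def __init__(self, batch_size):
          self.batch_size = batch_size

      def __init__(self, input_paths, target_mask_paths, batch_size,
                   num_samples=None, shape=None, use_fake_target_masks=False):
        self.num_examples = (len(input_paths) if num_samples is None
                             else num_samples)
        self.batch_size = batch_size

      def get_num_batches(self):
        return math.ceil(self.num_examples / self.batch_size)

      def get_batch(self):
        # Yields full batches, then one final smaller batch if needed.
        remaining = self.num_examples
        while remaining > 0:
          cur = min(self.batch_size, remaining)
          yield self.Batch(cur)
          remaining -= cur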
Example 4
  def train(self,
            sess,
            train_input_paths,
            train_target_mask_paths,
            dev_input_paths,
            dev_target_mask_paths):
    """
    Defines the training loop.

    Inputs:
    - sess: A TensorFlow Session object.
    - {train,dev}_{input_paths,target_mask_paths}: A list of Python strs
      that represent pathnames to input image files and target mask files.
    """
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
    logging.info(f"Number of trainable params: {num_params}")

    # We will keep track of exponentially-smoothed loss
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    best_dev_dice_coefficient = None

    # For TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, sess.graph)

    epoch = 0
    num_epochs = self.FLAGS.num_epochs
    while num_epochs is None or epoch < num_epochs:
      epoch += 1

      # Loops over batches
      sbg = SliceBatchGenerator(train_input_paths,
                                train_target_mask_paths,
                                self.FLAGS.batch_size,
                                shape=(self.FLAGS.slice_height,
                                       self.FLAGS.slice_width),
                                use_fake_target_masks=self.FLAGS.use_fake_target_masks)
      num_epochs_str = str(num_epochs) if num_epochs is not None else "indefinite"
      for batch in tqdm(sbg.get_batch(),
                        desc=f"Epoch {epoch}/{num_epochs_str}",
                        total=sbg.get_num_batches()):
        # Runs training iteration
        loss, global_step, param_norm, grad_norm =\
          self.run_train_iter(sess, batch, summary_writer)

        # Updates exponentially-smoothed loss
        if exp_loss is None:  # first iteration
          exp_loss = loss
        else:
          exp_loss = 0.99 * exp_loss + 0.01 * loss

        # Sometimes prints info
        if global_step % self.FLAGS.print_every == 0:
          logging.info(
            f"epoch {epoch}, "
            f"global_step {global_step}, "
            f"loss {loss}, "
            f"exp_loss {exp_loss}, "
            f"grad norm {grad_norm}, "
            f"param norm {param_norm}")

        # Sometimes saves model
        if (global_step % self.FLAGS.save_every == 0
            or global_step == sbg.get_num_batches()):
          self.saver.save(sess, checkpoint_path, global_step=global_step)

        # Sometimes evaluates model on dev loss and train/dev dice coefficients
        if global_step % self.FLAGS.eval_every == 0:
          # Logs loss for entire dev set to TensorBoard
          dev_loss = self.calculate_loss(sess,
                                         dev_input_paths,
                                         dev_target_mask_paths,
                                         "dev",
                                         self.FLAGS.dev_num_samples)
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"dev_loss {dev_loss}")
          utils.write_summary(dev_loss,
                              "dev/loss",
                              summary_writer,
                              global_step)

          # Logs dice coefficient on train set to TensorBoard
          train_dice = self.calculate_dice_coefficient(sess,
                                                       train_input_paths,
                                                       train_target_mask_paths,
                                                       "train")
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"train dice_coefficient: {train_dice}")
          utils.write_summary(train_dice,
                              "train/dice",
                              summary_writer,
                              global_step)

          # Logs dice coefficient on dev set to TensorBoard
          dev_dice = self.calculate_dice_coefficient(sess,
                                                     dev_input_paths,
                                                     dev_target_mask_paths,
                                                     "dev")
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"dev dice_coefficient: {dev_dice}")
          utils.write_summary(dev_dice,
                              "dev/dice",
                              summary_writer,
                              global_step)
      # end for batch in sbg.get_batch
    # end while num_epochs is None or epoch < num_epochs
    sys.stdout.flush()
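
The update exp_loss = 0.99 * exp_loss + 0.01 * loss is a standard exponential moving average with decay 0.99: it damps per-batch noise so the logged curve tracks the trend rather than individual batches. A tiny standalone demonstration on synthetic noisy losses:

    import random

    random.seed(0)
    exp_loss = None
    for step in range(1, 1001):
      # Synthetic loss: a decaying trend plus noise (illustration only).
      loss = 1.0 / step ** 0.5 + random.gauss(0, 0.05)
      if exp_loss is None:  # first iteration seeds the average
        exp_loss = loss
      else:
        exp_loss = 0.99 * exp_loss + 0.01 * loss
      if step % 250 == 0:
        print(f"step {step}: loss {loss:.4f}, exp_loss {exp_loss:.4f}")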
Example 5
  def calculate_dice_coefficient(self,
                                 sess,
                                 input_paths,
                                 target_mask_paths,
                                 dataset,
                                 num_samples=100,
                                 plot=False,
                                 print_to_screen=False):
    """
    Calculates the dice coefficient score for a dataset, represented by a
    list of {input_paths} and {target_mask_paths}.

    Inputs:
    - sess: A TensorFlow Session object.
    - input_paths: A list of Python strs that represent pathnames to input
      image files.
    - target_mask_paths: A list of Python strs that represent pathnames to
      target mask files.
    - dataset: A Python str that represents the dataset being tested. Options:
      {train,dev}. Just for logging purposes.
    - num_samples: A Python int that represents the number of samples to test.
      If num_samples=None, then tests the whole dataset.
    - plot: A Python bool. If True, saves a side-by-side plot of each example
      under {train_dir}/examples.
    - print_to_screen: A Python bool. If True, logs the dice coefficient of
      each valid example.

    Outputs:
    - dice_coefficient: A Python float that represents the average dice
      coefficient across the sampled examples.
    """
    logging.info(f"Calculating dice coefficient for {num_samples} examples "
                 f"from {dataset}...")
    tic = time.time()

    dice_coefficient_total = 0.
    num_examples = 0

    sbg = SliceBatchGenerator(input_paths,
                              target_mask_paths,
                              self.FLAGS.batch_size,
                              shape=(self.FLAGS.slice_height,
                                     self.FLAGS.slice_width),
                              use_fake_target_masks=self.FLAGS.use_fake_target_masks)
    for batch in sbg.get_batch():
      predicted_masks = self.get_predicted_masks_for_batch(sess, batch)

      zipped_masks = zip(predicted_masks,
                         batch.target_masks_batch,
                         batch.input_paths_batch,
                         batch.target_mask_path_lists_batch)
      for idx, (predicted_mask,
                target_mask,
                input_path,
                target_mask_path_list) in enumerate(zipped_masks):
        dice_coefficient = utils.dice_coefficient(predicted_mask, target_mask)
        if dice_coefficient >= 0.0:
          dice_coefficient_total += dice_coefficient
          num_examples += 1

          if print_to_screen:
            # Whee! We predicted at least one lesion pixel!
            logging.info(f"Dice coefficient of valid example {num_examples}: "
                         f"{dice_coefficient}")
          if plot:
            f, axarr = plt.subplots(1, 2)
            f.suptitle(input_path)
            axarr[0].imshow(predicted_mask)
            axarr[0].set_title("Predicted")
            axarr[1].imshow(target_mask)
            axarr[1].set_title("Target")
            examples_dir = os.path.join(self.FLAGS.train_dir, "examples")
            if not os.path.exists(examples_dir):
              os.makedirs(examples_dir)
            f.savefig(os.path.join(examples_dir, str(num_examples).zfill(4)))
            plt.close(f)  # Frees the figure so memory doesn't grow per example

        if num_samples is not None and num_examples >= num_samples:
          break

      if num_samples is not None and num_examples >= num_samples:
        break

    dice_coefficient_mean = dice_coefficient_total / num_examples

    toc = time.time()
    logging.info(f"Calculating dice coefficient took {toc-tic} sec.")
    return dice_coefficient_mean
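
utils.dice_coefficient is not shown here. For binary masks the Dice coefficient is 2|A ∩ B| / (|A| + |B|). A minimal NumPy sketch, with the assumption (suggested by the >= 0.0 filter above) that an empty prediction-and-target pair is reported as a negative sentinel so callers can skip it:

    import numpy as np

    def dice_coefficient(predicted_mask, target_mask):
      # Dice = 2 * |A intersect B| / (|A| + |B|) for binary masks.
      # Returns -1.0 when both masks are empty; this sentinel convention
      # is an assumption, and utils.dice_coefficient may differ.
      pred = np.asarray(predicted_mask, dtype=bool)
      target = np.asarray(target_mask, dtype=bool)
      denom = pred.sum() + target.sum()
      if denom == 0:
        return -1.0
      return 2.0 * np.logical_and(pred, target).sum() / float(denom)

    # Half-overlapping masks: Dice = 2*1 / (2+1) ~= 0.667.
    print(dice_coefficient([[1, 1], [0, 0]], [[1, 0], [0, 0]]))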
Example 6
  def calculate_dice_coefficient(self,
                                 sess,
                                 input_paths,
                                 target_mask_paths,
                                 dataset,
                                 num_samples=100,
                                 plot=False,
                                 print_to_screen=False):
    """
    Calculates the dice coefficient score for a dataset, represented by a
    list of {input_paths} and {target_mask_paths}.

    Inputs:
    - sess: A TensorFlow Session object.
    - input_paths: A list of Python strs that represent pathnames to input
      image files.
    - target_mask_paths: A list of Python strs that represent pathnames to
      target mask files.
    - dataset: A Python str that represents the dataset being tested. Options:
      {train,dev}. Just for logging purposes.
    - num_samples: A Python int that represents the number of samples to test.
      If num_samples=None, then tests the whole dataset.
    - plot: A Python bool. If True, saves a side-by-side plot of each example
      under {train_dir}/examples.
    - print_to_screen: A Python bool. If True, logs the dice coefficient of
      each valid example.

    Outputs:
    - dice_coefficient: A Python float that represents the average dice
      coefficient across the sampled examples.
    """
    logging.info(f"Calculating dice coefficient for {num_samples} examples "
                 f"from {dataset}...")
    tic = time.time()

    dice_coefficient_total = 0.
    num_examples = 0

    # To be used to save mask sizes for comparison
    predicted_mask_sizes = []
    target_mask_sizes = []

    sbg = SliceBatchGenerator(input_paths,
                              target_mask_paths,
                              self.FLAGS.batch_size,
                              shape=(self.FLAGS.slice_height,
                                     self.FLAGS.slice_width),
                              use_fake_target_masks=self.FLAGS.use_fake_target_masks)
    for batch in sbg.get_batch():
      predicted_masks = self.get_predicted_masks_for_batch(sess, batch)

      zipped_masks = zip(predicted_masks,
                         batch.target_masks_batch,
                         batch.input_paths_batch,
                         batch.target_mask_path_lists_batch)
      for idx, (predicted_mask,
                target_mask,
                input_path,
                target_mask_path_list) in enumerate(zipped_masks):
        dice_coefficient = utils.dice_coefficient(predicted_mask, target_mask)
        if dice_coefficient >= 0.0:
          dice_coefficient_total += dice_coefficient
          num_examples += 1

          if print_to_screen:
            # Whee! We predicted at least one lesion pixel!
            logging.info(f"Dice coefficient of valid example {num_examples}: "
                         f"{dice_coefficient}")
          if plot and self.FLAGS.mode == 'eval':
            # Saves mask sizes for the size-comparison chart below
            predicted_mask_sizes.append(np.sum(predicted_mask))
            target_mask_sizes.append(np.sum(target_mask))

            f, axarr = plt.subplots(1, 2)
            f.suptitle(input_path)
            axarr[0].imshow(predicted_mask)
            axarr[0].set_title("Predicted")
            axarr[1].imshow(target_mask)
            axarr[1].set_title("Target")
            examples_dir = os.path.join(self.FLAGS.train_dir, "examples")
            if not os.path.exists(examples_dir):
              os.makedirs(examples_dir)
            f.savefig(os.path.join(examples_dir, str(num_examples).zfill(4)))
            plt.close(f)  # Frees the figure so memory doesn't grow per example

        if num_samples is not None and num_examples >= num_samples:
          break

      if num_samples is not None and num_examples >= num_samples:
        break

    if (self.FLAGS.mode == 'eval' and predicted_mask_sizes
        and num_samples is not None and num_samples < 200):
      # Sorts both lists by predicted size so the chart reads from the
      # smallest prediction to the largest.
      predicted_mask_sizes = np.array(predicted_mask_sizes)
      target_mask_sizes = np.array(target_mask_sizes)
      args = predicted_mask_sizes.argsort()
      predicted_mask_sizes = predicted_mask_sizes[args]
      target_mask_sizes = target_mask_sizes[args]
      fig, ax = plt.subplots()
      ind = 2 * np.arange(len(predicted_mask_sizes))
      rects1 = ax.bar(ind, predicted_mask_sizes, 0.5, color='r')
      rects2 = ax.bar(ind + 0.5, target_mask_sizes, 0.5, color='b')
      ax.set_ylabel('Size (in pixels)')
      ax.set_title('Predicted and Target Mask Sizes')
      ax.set_xticks(ind + 0.25)
      ax.set_xticklabels(np.arange(len(predicted_mask_sizes)))
      ax.legend((rects1[0], rects2[0]), ('Predicted', 'Target'))
      fig.savefig(os.path.join(self.FLAGS.train_dir, 'relative_sizes'))

    dice_coefficient_mean = dice_coefficient_total / num_examples

    toc = time.time()
    logging.info(f"Calculating dice coefficient took {toc-tic} sec.")
    return dice_coefficient_mean
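
The argsort step above sorts two parallel arrays by one array's order: argsort returns the permutation that sorts predicted_mask_sizes, and indexing both arrays with it keeps each predicted/target pair aligned. A standalone illustration with made-up sizes:

    import numpy as np

    predicted = np.array([30, 5, 120])  # made-up mask sizes
    target = np.array([28, 9, 100])

    order = predicted.argsort()   # permutation that sorts `predicted`
    print(predicted[order])       # [  5  30 120]
    print(target[order])          # [  9  28 100]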