def calculate_loss(self, sess, input_paths, target_mask_paths, dataset,
                   num_samples=None):
    """
    Calculates the loss for a dataset, represented by a list of {input_paths}
    and {target_mask_paths}.

    Inputs:
    - sess: A TensorFlow Session object.
    - input_paths: A list of Python strs that represent pathnames to input
      image files.
    - target_mask_paths: A list of Python strs that represent pathnames to
      target mask files.
    - dataset: A Python str that represents the dataset being tested.
      Options: {train,dev}. Just for logging purposes.
    - num_samples: A Python int that represents the number of samples to
      test. If num_samples=None, then test whole dataset.

    Outputs:
    - loss: A Python float that represents the average loss across the
      sampled examples.
    """
    logging.info(f"Calculating loss for {num_samples} examples from "
                 f"{dataset}...")
    tic = time.time()

    loss_per_batch, batch_sizes = [], []

    sbg = SliceBatchGenerator(
        input_paths,
        target_mask_paths,
        self.FLAGS.batch_size,
        num_samples=num_samples,
        shape=tuple(
            [self.FLAGS.slice_height, self.FLAGS.slice_width]
            + ([self.FLAGS.scan_depth] if self.FLAGS.use_volumetric else [])),
        use_fake_target_masks=self.FLAGS.use_fake_target_masks)

    # Iterates over batches
    for batch in sbg.get_batch():
        # Gets loss for this batch, weighted by the batch size so that the
        # final average is per example rather than per batch
        loss = self.get_loss_for_batch(sess, batch)
        cur_batch_size = batch.batch_size
        loss_per_batch.append(loss * cur_batch_size)
        batch_sizes.append(cur_batch_size)

    # Overall loss is total loss divided by total number of examples
    total_num_examples = sum(batch_sizes)
    loss = sum(loss_per_batch) / float(total_num_examples)

    toc = time.time()
    logging.info(f"Calculating loss took {toc - tic:.2f} sec.")
    return loss
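# calculate_loss relies on get_loss_for_batch, which is defined elsewhere in
# this class. For reference, a minimal sketch of what such a method could look
# like, assuming the graph exposes `self.loss`, `self.inputs_op`,
# `self.target_masks_op`, and `self.keep_prob` placeholders (these names are
# assumptions, not confirmed by this file):
#
#     def get_loss_for_batch(self, sess, batch):
#         """Runs one forward pass and returns the scalar loss for {batch}."""
#         input_feed = {
#             self.inputs_op: batch.inputs_batch,
#             self.target_masks_op: batch.target_masks_batch,
#             self.keep_prob: 1.0,  # no dropout at evaluation time
#         }
#         return sess.run(self.loss, input_feed)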
def get_dev_loss(self, sess, dev_input_paths, dev_target_mask_paths,
                 num_samples=None):
    """
    Gets the loss for the dev set by delegating to {calculate_loss}, which
    implements the same batching and per-example averaging logic.

    Inputs:
    - sess: A TensorFlow Session object.
    - dev_input_paths: A list of Python strs that represent pathnames to
      input image files in the dev set.
    - dev_target_mask_paths: A list of Python strs that represent pathnames
      to target mask files in the dev set.
    - num_samples: A Python int or None. If None, then evaluates on the
      entire dev set.

    Outputs:
    - dev_loss: A Python float that represents the average loss across the
      dev set.
    """
    return self.calculate_loss(sess, dev_input_paths, dev_target_mask_paths,
                               "dev", num_samples=num_samples)
def get_batch_generator(self, input_paths, target_mask_paths,
                        num_samples=None, flip_images=True):
    """
    Returns a SliceBatchGenerator over {input_paths} and {target_mask_paths}.

    Note: {flip_images} is accepted but not currently forwarded to
    SliceBatchGenerator, so it has no effect.
    """
    return SliceBatchGenerator(
        input_paths,
        target_mask_paths,
        self.FLAGS.batch_size,
        num_samples=num_samples,
        shape=(self.FLAGS.slice_height, self.FLAGS.slice_width),
        use_fake_target_masks=self.FLAGS.use_fake_target_masks)
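# A usage sketch for the generator above (assumes a live session, path lists
# already in scope, and `model` standing in for an instance of this class;
# all hypothetical names for illustration):
#
#     sbg = model.get_batch_generator(input_paths, target_mask_paths,
#                                     num_samples=32)
#     for batch in sbg.get_batch():
#         loss = model.get_loss_for_batch(sess, batch)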
def train(self, sess, train_input_paths, train_target_mask_paths,
          dev_input_paths, dev_target_mask_paths):
    """
    Defines the training loop.

    Inputs:
    - sess: A TensorFlow Session object.
    - {train,dev}_{input_paths,target_mask_paths}: A list of Python strs
      that represent pathnames to input image files and target mask files.
    """
    # Counts trainable parameters from their static shapes, which avoids
    # running the graph just to read off dimensions
    params = tf.trainable_variables()
    num_params = sum(
        int(np.prod(param.get_shape().as_list())) for param in params)
    logging.info(f"Number of trainable parameters: {num_params}")

    # We will keep track of exponentially-smoothed loss
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    best_dev_dice_coefficient = None

    # For TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, sess.graph)

    epoch = 0
    num_epochs = self.FLAGS.num_epochs
    while num_epochs is None or epoch < num_epochs:
        epoch += 1

        # Loops over batches
        sbg = SliceBatchGenerator(
            train_input_paths,
            train_target_mask_paths,
            self.FLAGS.batch_size,
            shape=(self.FLAGS.slice_height, self.FLAGS.slice_width),
            use_fake_target_masks=self.FLAGS.use_fake_target_masks)
        num_epochs_str = (str(num_epochs) if num_epochs is not None
                          else "indefinite")
        for batch in tqdm(sbg.get_batch(),
                          desc=f"Epoch {epoch}/{num_epochs_str}",
                          total=sbg.get_num_batches()):
            # Runs training iteration
            loss, global_step, param_norm, grad_norm = \
                self.run_train_iter(sess, batch, summary_writer)

            # Updates exponentially-smoothed loss
            if exp_loss is None:  # first iter
                exp_loss = loss
            else:
                exp_loss = 0.99 * exp_loss + 0.01 * loss

            # Sometimes prints info
            if global_step % self.FLAGS.print_every == 0:
                logging.info(
                    f"epoch {epoch}, "
                    f"global_step {global_step}, "
                    f"loss {loss}, "
                    f"exp_loss {exp_loss}, "
                    f"grad norm {grad_norm}, "
                    f"param norm {param_norm}")

            # Sometimes saves model
            if (global_step % self.FLAGS.save_every == 0
                    or global_step == sbg.get_num_batches()):
                self.saver.save(sess, checkpoint_path,
                                global_step=global_step)

            # Sometimes evaluates model on dev loss and train/dev dice
            # coefficients
            if global_step % self.FLAGS.eval_every == 0:
                # Logs loss for entire dev set to TensorBoard
                dev_loss = self.calculate_loss(
                    sess, dev_input_paths, dev_target_mask_paths, "dev",
                    self.FLAGS.dev_num_samples)
                logging.info(f"epoch {epoch}, "
                             f"global_step {global_step}, "
                             f"dev_loss {dev_loss}")
                utils.write_summary(
                    dev_loss, "dev/loss", summary_writer, global_step)

                # Logs dice coefficient on train set to TensorBoard
                train_dice = self.calculate_dice_coefficient(
                    sess, train_input_paths, train_target_mask_paths,
                    "train")
                logging.info(f"epoch {epoch}, "
                             f"global_step {global_step}, "
                             f"train dice_coefficient: {train_dice}")
                utils.write_summary(
                    train_dice, "train/dice", summary_writer, global_step)

                # Logs dice coefficient on dev set to TensorBoard
                dev_dice = self.calculate_dice_coefficient(
                    sess, dev_input_paths, dev_target_mask_paths, "dev")
                logging.info(f"epoch {epoch}, "
                             f"global_step {global_step}, "
                             f"dev dice_coefficient: {dev_dice}")
                utils.write_summary(
                    dev_dice, "dev/dice", summary_writer, global_step)
        # end for batch in sbg.get_batch
    # end while num_epochs is None or epoch < num_epochs

    sys.stdout.flush()
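# train relies on run_train_iter and utils.write_summary, both defined
# elsewhere. Minimal sketches of their expected contracts, assuming the graph
# exposes `self.updates`, `self.summaries`, `self.loss`, `self.global_step`,
# `self.param_norm`, and `self.gradient_norm` ops (these names are
# assumptions, not confirmed by this file):
#
#     def run_train_iter(self, sess, batch, summary_writer):
#         input_feed = {
#             self.inputs_op: batch.inputs_batch,
#             self.target_masks_op: batch.target_masks_batch,
#             self.keep_prob: 1.0 - self.FLAGS.dropout,
#         }
#         output_feed = [self.updates, self.summaries, self.loss,
#                        self.global_step, self.param_norm,
#                        self.gradient_norm]
#         _, summaries, loss, global_step, param_norm, grad_norm = \
#             sess.run(output_feed, input_feed)
#         summary_writer.add_summary(summaries, global_step)
#         return loss, global_step, param_norm, grad_norm
#
#     # In utils: wraps a Python float in a TF Summary proto so a plain
#     # number can be logged to TensorBoard
#     def write_summary(value, tag, summary_writer, global_step):
#         summary = tf.Summary()
#         summary.value.add(tag=tag, simple_value=value)
#         summary_writer.add_summary(summary, global_step)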
def calculate_dice_coefficient(self, sess, input_paths, target_mask_paths,
                               dataset, num_samples=100, plot=False,
                               print_to_screen=False):
    """
    Calculates the dice coefficient score for a dataset, represented by a
    list of {input_paths} and {target_mask_paths}.

    Inputs:
    - sess: A TensorFlow Session object.
    - input_paths: A list of Python strs that represent pathnames to input
      image files.
    - target_mask_paths: A list of Python strs that represent pathnames to
      target mask files.
    - dataset: A Python str that represents the dataset being tested.
      Options: {train,dev}. Just for logging purposes.
    - num_samples: A Python int that represents the number of samples to
      test. If num_samples=None, then test whole dataset.
    - plot: A Python bool. If True, saves a side-by-side plot of the
      predicted and target masks for each example.
    - print_to_screen: A Python bool. If True, logs the dice coefficient of
      each valid example.

    Outputs:
    - dice_coefficient: A Python float that represents the average dice
      coefficient across the sampled examples.
    """
    logging.info(f"Calculating dice coefficient for {num_samples} examples "
                 f"from {dataset}...")
    tic = time.time()

    dice_coefficient_total = 0.
    num_examples = 0

    # To be used to save mask sizes for comparison
    predicted_mask_sizes = []
    target_mask_sizes = []

    sbg = SliceBatchGenerator(
        input_paths,
        target_mask_paths,
        self.FLAGS.batch_size,
        shape=(self.FLAGS.slice_height, self.FLAGS.slice_width),
        use_fake_target_masks=self.FLAGS.use_fake_target_masks)

    for batch in sbg.get_batch():
        predicted_masks = self.get_predicted_masks_for_batch(sess, batch)

        zipped_masks = zip(predicted_masks,
                           batch.target_masks_batch,
                           batch.input_paths_batch,
                           batch.target_mask_path_lists_batch)
        for idx, (predicted_mask,
                  target_mask,
                  input_path,
                  target_mask_path_list) in enumerate(zipped_masks):
            dice_coefficient = utils.dice_coefficient(predicted_mask,
                                                      target_mask)
            # A negative dice coefficient marks an invalid example, which is
            # excluded from the average
            if dice_coefficient >= 0.0:
                dice_coefficient_total += dice_coefficient
                num_examples += 1

                if print_to_screen:
                    # Whee! We predicted at least one lesion pixel!
                    logging.info(
                        f"Dice coefficient of valid example {num_examples}: "
                        f"{dice_coefficient}")
                if plot:
                    if self.FLAGS.mode == 'eval':
                        # Saves mask sizes for the size-comparison chart below
                        predicted_mask_sizes.append(np.sum(predicted_mask))
                        target_mask_sizes.append(np.sum(target_mask))
                    f, axarr = plt.subplots(1, 2)
                    f.suptitle(input_path)
                    axarr[0].imshow(predicted_mask)
                    axarr[0].set_title("Predicted")
                    axarr[1].imshow(target_mask)
                    axarr[1].set_title("Target")
                    examples_dir = os.path.join(self.FLAGS.train_dir,
                                                "examples")
                    if not os.path.exists(examples_dir):
                        os.makedirs(examples_dir)
                    f.savefig(os.path.join(examples_dir,
                                           str(num_examples).zfill(4)))
                    # Frees the figure so long runs do not accumulate
                    # open figures
                    plt.close(f)

            if num_samples is not None and num_examples >= num_samples:
                break
        if num_samples is not None and num_examples >= num_samples:
            break

    # Plots predicted vs. target mask sizes; only worthwhile for small sample
    # counts, and only possible if sizes were collected above (i.e. when
    # plot=True and mode is 'eval')
    if (self.FLAGS.mode == 'eval' and len(predicted_mask_sizes) > 0
            and num_samples is not None and num_samples < 200):
        predicted_mask_sizes = np.array(predicted_mask_sizes)
        target_mask_sizes = np.array(target_mask_sizes)

        # Sorts both series by predicted size so the chart is readable
        args = predicted_mask_sizes.argsort()
        predicted_mask_sizes = predicted_mask_sizes[args]
        target_mask_sizes = target_mask_sizes[args]

        fig, ax = plt.subplots()
        ind = 2 * np.arange(len(predicted_mask_sizes))
        rects1 = ax.bar(ind, predicted_mask_sizes, 0.5, color='r')
        rects2 = ax.bar(ind + 0.5, target_mask_sizes, 0.5, color='b')
        ax.set_ylabel('Size (in pixels)')
        ax.set_title('Predicted and Target Mask Sizes')
        ax.set_xticks(ind + 0.25)
        ax.set_xticklabels(np.arange(len(predicted_mask_sizes)))
        ax.legend((rects1[0], rects2[0]), ('Predicted', 'Target'))
        fig.savefig(os.path.join(self.FLAGS.train_dir, 'relative_sizes'))
        plt.close(fig)

    # Guards against division by zero when no valid examples were seen
    dice_coefficient_mean = dice_coefficient_total / max(num_examples, 1)

    toc = time.time()
    logging.info(f"Calculating dice coefficient took {toc - tic:.2f} sec.")
    return dice_coefficient_mean
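# calculate_dice_coefficient treats a negative return value from
# utils.dice_coefficient as "invalid example". A plausible implementation
# consistent with that convention (a sketch; the real utils module is not
# shown here, and the -1.0 sentinel for empty masks is an assumption):
#
#     def dice_coefficient(predicted_mask, target_mask):
#         """Returns 2|P & T| / (|P| + |T|), or -1.0 if both masks are
#         empty."""
#         predicted = np.asarray(predicted_mask) > 0
#         target = np.asarray(target_mask) > 0
#         denominator = predicted.sum() + target.sum()
#         if denominator == 0:
#             return -1.0
#         return 2.0 * np.logical_and(predicted, target).sum() / denominator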