# TEST_DATASET_DIR="./dataset/" TEST_FILE = 'test.tfrecords' test_filenames = [str(input_args.data_dir / TEST_FILE)] test_dataset = tf.data.TFRecordDataset(test_filenames) test_dataset = test_dataset.map(tf_record_parser) # Parse the record into tensors. test_dataset = test_dataset.map(scale_image_with_crop_padding) test_dataset = test_dataset.shuffle(buffer_size=100) test_dataset = test_dataset.batch(args.batch_size) iterator = test_dataset.make_one_shot_iterator() batch_images_tf, batch_labels_tf, batch_shapes_tf = iterator.get_next() logits_tf = network.deeplab_v3( batch_images_tf, args, is_training=False, reuse=False, ) valid_labels_batch_tf, valid_logits_batch_tf = ( training.get_valid_logits_and_labels( annotation_batch_tensor=batch_labels_tf, logits_batch_tensor=logits_tf, class_labels=class_labels, ) ) cross_entropies_tf = tf.nn.softmax_cross_entropy_with_logits( logits=valid_logits_batch_tf, labels=valid_labels_batch_tf, )
training_dataset.output_types, training_dataset.output_shapes)
# The line above closes a call whose start is not visible here (presumably
# tf.data.Iterator.from_string_handle(handle, ...)). The resulting feedable
# iterator is switched between datasets by feeding a string handle.
batch_images_tf, batch_labels_tf, _ = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# Class ids 0 .. number_of_classes-1, followed by 255 (the padding/ignore
# label) as the final entry.
class_labels = [v for v in range((args.number_of_classes + 1))]
class_labels[-1] = 255

# Scalar switch between the training-mode and inference-mode network.
is_training_tf = tf.placeholder(tf.bool, shape=[])

# Both branches are built into the graph; the inference branch passes
# reuse=True, presumably so it shares the training branch's variables.
logits_tf = tf.cond(is_training_tf, true_fn=lambda: network.deeplab_v3(
    batch_images_tf, args, is_training=True, reuse=False),
    false_fn=lambda: network.deeplab_v3(
        batch_images_tf, args, is_training=False, reuse=True))

# get valid logits and labels (factor the 255 padded mask out for cross entropy)
valid_labels_batch_tf, valid_logits_batch_tf = training.get_valid_logits_and_labels(
    annotation_batch_tensor=batch_labels_tf,
    logits_batch_tensor=logits_tf,
    class_labels=class_labels)

# Per-pixel cross entropy over the valid pixels, averaged into one scalar.
cross_entropies = tf.nn.softmax_cross_entropy_with_logits_v2(
    logits=valid_logits_batch_tf,
    labels=valid_labels_batch_tf)
cross_entropy_tf = tf.reduce_mean(cross_entropies)

# Predicted class id per pixel (argmax over axis 3 of the logits).
predictions_tf = tf.argmax(logits_tf, axis=3)

tf.summary.scalar('cross_entropy', cross_entropy_tf)
def main(_):
    """Restore the trained DeepLab-v3 checkpoint and export a SavedModel.

    The exported meta graph is tagged SERVING and carries a single
    ``predict_images`` signature:

    * inputs: ``images`` (the tensor named ``x``), ``height`` and ``width``;
    * output: ``segmentation_map``, the per-pixel argmax over the logits.

    Because the signature's ``images`` input is the ``x`` tensor itself,
    clients feed pixel data directly rather than a serialized tf.Example.

    Args:
        _: Unused positional argument.
    """
    with tf.Session(graph=tf.Graph()) as sess:
        # Runtime image dimensions. They are created first so they keep the
        # same default op names as before.
        height_ph = tf.placeholder(tf.int32)
        width_ph = tf.placeholder(tf.int32)

        # Serialized tf.Example with the flattened float image under 'x'.
        serialized_example = tf.placeholder(tf.string, name='tf_example')
        parsed = tf.parse_example(
            serialized_example,
            {'x': tf.FixedLenFeature(shape=[], dtype=tf.float32)},
        )

        # Restore the (1, height, width, 3) layout, then name it 'x'.
        reshaped = tf.reshape(parsed['x'], (1, height_ph, width_ph, 3))
        images = tf.identity(reshaped, name='x')

        # Inference-mode network followed by per-pixel class selection.
        segmentation = tf.argmax(
            network.deeplab_v3(images, args, is_training=False, reuse=False),
            axis=3,
        )

        checkpoint_dir = os.path.join(log_folder, model_name, "train")
        saver = tf.train.Saver()
        saver.restore(sess, os.path.join(checkpoint_dir, "model.ckpt"))
        print("Model", model_name, "restored.")

        # <export_model_dir>/<model_version>, as bytes.
        export_path = os.path.join(
            tf.compat.as_bytes(FLAGS.export_model_dir),
            tf.compat.as_bytes(str(FLAGS.model_version)),
        )
        print('Exporting trained model to', export_path)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        as_info = tf.saved_model.utils.build_tensor_info
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={
                'images': as_info(images),
                'height': as_info(height_ph),
                'width': as_info(width_ph),
            },
            outputs={'segmentation_map': as_info(segmentation)},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )

        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={'predict_images': signature},
        )
        builder.save(as_text=True)
        print('Done exporting!')
def infer(self, image_dir, batch_size, ckpt, output_folder, organ):
    """Predict organ masks for a folder of images and refine them.

    Each prediction is processed in five steps:

    1. Upsample to 512x512, fill holes and (for the spleen) crop.
    2. Remove components smaller than 50 pixels.
    3. Grow the region with a short intensity-based pass on the 3x3-mean
       image.
    4. Filter the result with a 5x5-mean intensity band.
    5. Write it to ``output_folder`` under the input image's file name.

    Args:
        image_dir: Directory where the ``.png``/``.jpg`` images are located.
        batch_size: Batch size used to read the images. The post-processing
            only uses ``image[0]`` and indexes ``image_paths`` by step, so this
            presumably has to be 1 — TODO confirm.
        ckpt: Path of the checkpoint file to restore.
        output_folder: Folder where the predicted masks are saved.
        organ: Target organ name; for ``"spleen"`` mask columns 0-255 are
            cleared.
    """
    from collections import deque

    def _enqueue_neighbours(grid, queue, x, y):
        """Queue the unseen (0-valued) 4-neighbours of (x, y), marking them 1."""
        rows, cols = grid.shape
        for nx, ny in ((x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)):
            if 0 <= nx < rows and 0 <= ny < cols and grid[nx, ny] == 0:
                grid[nx, ny] = 1
                queue.append((nx, ny))

    print(image_dir)
    image_paths = [
        os.path.join(image_dir, name) for name in os.listdir(image_dir)
        if name.endswith(('.png', '.jpg'))
    ]
    infer_data, infer_queue_init = utility.data_batch(image_paths,
                                                      None,
                                                      batch_size,
                                                      shuffle=False)

    image_ph = None
    mask = None
    # NOTE(review): when self.restored is already True, image_ph and mask
    # stay None and the sess.run(mask, {image_ph: ...}) call below cannot
    # work; only the self.restored == False path is functional.
    if not self.restored:
        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
        if self.logits is None:
            self.logits = network.deeplab_v3(image_ph,
                                             is_training=True,
                                             reuse=tf.AUTO_REUSE)
        mask = tf.squeeze(tf.argmax(self.logits, axis=3))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        if not self.restored:
            saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        for step in range(len(image_paths) // batch_size):
            image = sess.run(infer_data)
            prediction = sess.run(mask, {image_ph: image})

            # Upsample the class map to 512x512 and fill interior holes.
            prediction = binary_fill_holes(
                resize((255 * prediction[:, :]).astype(np.uint8),
                       (512, 512))) * 255
            if organ == "spleen":
                prediction[:, 0:256] = 0
            out_path = os.path.join(output_folder,
                                    os.path.basename(image_paths[step]))
            cv2.imwrite(out_path, prediction)
            print(out_path)

            # remove_small_objects only separates connected components for
            # boolean input; a 0/255 integer array is treated as one label.
            prediction = remove_small_objects(prediction.astype(bool),
                                              min_size=50,
                                              connectivity=1) * 255

            image = resize((255 * image[0, :, :, 0]).astype(np.uint8),
                           (512, 512)) * 255
            # 3x3 and 5x5 box-mean filtered copies of the image.
            mean3_image = convolve2d(image, np.ones((3, 3)) / 9,
                                     mode='same', boundary='symm')
            mean5_image = convolve2d(image, np.ones((5, 5)) / 25,
                                     mode='same', boundary='symm')

            # Intensity statistics over the predicted region.
            region = prediction == 255
            segstd = np.std(image[region])
            segmean = np.mean(image[region])
            mean3_segstd = np.std(mean3_image[region])
            mean3_segmean = np.mean(mean3_image[region])
            mean5_segstd = np.std(mean5_image[region])
            mean5_segmean = np.mean(mean5_image[region])
            print(segstd, segmean)

            # Region growing. Marks: 0 unseen, 1 queued, 255 accepted,
            # -1 rejected. Seed the queue with the region's border pixels.
            visited = copy.deepcopy(prediction)
            queue_visit = deque()
            for x in range(1, 511):
                for y in range(1, 511):
                    if visited[x, y] == 255:
                        _enqueue_neighbours(visited, queue_visit, x, y)

            # Only the first 20 queued pixels are examined.
            visit_budget = 20
            while queue_visit and visit_budget:
                x, y = queue_visit.popleft()
                # Reject pixels whose 3x3 mean lies outside
                # mean +/- 0.6 std of the predicted region.
                if (mean3_image[x, y] > mean3_segmean + 0.6 * mean3_segstd
                        or mean3_image[x, y]
                        < mean3_segmean - 0.6 * mean3_segstd):
                    visited[x, y] = -1
                else:
                    visited[x, y] = 255
                    _enqueue_neighbours(visited, queue_visit, x, y)
                visit_budget -= 1

            # Keep accepted pixels only: rejected (-1) and still-queued (1)
            # markers must not end up in the written mask.
            visited[visited != 255] = 0
            prediction = visited

            # Drop pixels whose 5x5 mean lies outside mean +/- 1.5 std.
            refined = copy.deepcopy(prediction)
            refined[mean5_image > mean5_segmean + 1.5 * mean5_segstd] = 0
            refined[mean5_image < mean5_segmean - 1.5 * mean5_segstd] = 0
            cv2.imwrite(out_path, refined)
def train(self, train_path, val_path, save_dir, batch_size, epochs,
          learning_rate):
    """Train the segmentation network and validate it after every epoch.

    The network is built with ``network.deeplab_v3``. Training resumes from
    the latest checkpoint found for ``save_dir`` when one exists; otherwise it
    starts from freshly initialized weights. A checkpoint is saved at the end
    of every epoch.

    Args:
        train_path: Directory with ``images`` and ``masks_old`` subfolders
            holding the training data.
        val_path: Directory with ``images`` and ``masks_old`` subfolders
            holding the validation data.
        save_dir: Used in three places:
            - as the prefix for ``Saver.save``;
            - as the directory searched by ``tf.train.latest_checkpoint``;
            - through ``os.path.dirname(save_dir)``, as the TensorBoard log
              directory.
            It is presumably given with a trailing separator so all three
            point at the same folder — verify.
        batch_size: Batch size for training and validation.
        epochs: Number of epochs to train for.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        None
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    train_image_path = os.path.join(train_path, 'images')
    train_mask_path = os.path.join(train_path, 'masks_old')
    val_image_path = os.path.join(val_path, 'images')
    val_mask_path = os.path.join(val_path, 'masks_old')

    assert os.path.exists(
        train_image_path), "No training image folder found"
    assert os.path.exists(train_mask_path), "No training mask folder found"
    assert os.path.exists(
        val_image_path), "No validation image folder found"
    assert os.path.exists(val_mask_path), "No validation mask folder found"

    train_image_paths, train_mask_paths = get_data_paths_list(
        train_image_path, train_mask_path)
    val_image_paths, val_mask_paths = get_data_paths_list(
        val_image_path, val_mask_path)

    assert len(train_image_paths) == len(
        train_mask_paths
    ), "Number of images and masks dont match in train folder"
    assert len(val_image_paths) == len(
        val_mask_paths
    ), "Number of images and masks dont match in validation folder"

    self.num_train_images = len(train_image_paths)
    self.num_val_images = len(val_image_paths)

    print("Loading Data")
    train_data, train_queue_init = utility.data_batch(
        train_image_paths, train_mask_paths, batch_size)
    train_image_tensor, train_mask_tensor = train_data
    eval_data, eval_queue_init = utility.data_batch(
        val_image_paths, val_mask_paths, batch_size)
    eval_image_tensor, eval_mask_tensor = eval_data
    print("Loading Data Finished")

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
    mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
    # NOTE(review): `training` is fed below, but the network is built with a
    # constant is_training=True, so this placeholder currently has no effect.
    training = tf.placeholder(tf.bool, shape=[])

    # NOTE(review): if self.logits was built earlier on a different
    # placeholder, feeding image_ph below will not reach it.
    if self.logits is None:
        self.logits = network.deeplab_v3(image_ph,
                                         is_training=True,
                                         reuse=False)

    loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))
    with tf.variable_scope("mean_iou_train"):
        iou, iou_update = self.calculate_iou(mask_ph, self.logits)
    merged = tf.summary.merge_all()

    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Run the UPDATE_OPS collection (e.g. batch-norm moving averages) with
    # every optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = optimizer.minimize(loss)

    # Re-initializing the metric's local variables resets the running IoU.
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                     scope="mean_iou_train")
    reset_iou = tf.variables_initializer(var_list=running_vars)

    saver = tf.train.Saver(max_to_keep=30)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    print(self.num_train_images)

    # NOTE(review): the factor 13 is kept as-is; presumably data_batch yields
    # about 13 (augmented) samples per training image — verify.
    num_train_steps = (13 * self.num_train_images) // batch_size - 1
    num_val_steps = self.num_val_images // batch_size

    with tf.Session(config=config) as sess:
        print("Initializing Variables")
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        print("Initializing Variables Finished")

        # latest_checkpoint returns None when nothing has been saved yet;
        # saver.restore(sess, None) would raise, so only restore if found.
        latest_ckpt = tf.train.latest_checkpoint(save_dir)
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
            print("Checkpoint restored")
        else:
            print("No checkpoint found, training from scratch")

        # One writer for the whole run, closed even if training fails.
        writer = tf.summary.FileWriter(os.path.dirname(save_dir), sess.graph)
        try:
            for epoch in range(epochs):
                print("Epoch: ", epoch)
                print("Epoch queue init start")
                sess.run([train_queue_init, eval_queue_init])
                print("Epoch queue init ends")
                total_train_cost, total_val_cost = 0, 0
                total_train_iou, total_val_iou = 0, 0

                for train_step in range(num_train_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [train_image_tensor, train_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    cost, _, _, summary = sess.run(
                        [loss, opt, iou_update, merged], feed_dict=feed_dict)
                    train_iou = sess.run(iou, feed_dict=feed_dict)
                    total_train_cost += cost
                    total_train_iou += train_iou
                    # Global step, so later epochs do not overwrite the
                    # TensorBoard points of earlier ones.
                    writer.add_summary(summary,
                                       epoch * num_train_steps + train_step)
                    if train_step % 50 == 0:
                        print("Step: ", train_step, "Cost: ", cost, "IoU:",
                              train_iou)

                for _ in range(num_val_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [eval_image_tensor, eval_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: False
                    }
                    eval_cost, _ = sess.run([loss, iou_update],
                                            feed_dict=feed_dict)
                    eval_iou = sess.run(iou, feed_dict=feed_dict)
                    total_val_cost += eval_cost
                    total_val_iou += eval_iou

                # Average over the number of batches actually run (guarding
                # against zero batches).
                train_batches = max(num_train_steps, 1)
                val_batches = max(num_val_steps, 1)
                print("Epoch: {0}, training loss: {1}, validation loss: {2}".
                      format(epoch, total_train_cost / train_batches,
                             total_val_cost / val_batches))
                print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                    epoch, total_train_iou / train_batches,
                    total_val_iou / val_batches))
                print("Saving model...")
                saver.save(sess, save_dir, global_step=epoch)
        finally:
            writer.close()
# Decode a PNG-encoded request image into a float batch, run the network and
# shape the class map for output.
# NOTE(review): the decode/cast/expand_dims ops below are created outside
# `with graph.as_default()`. Unless `graph` is already the default graph
# here (not visible in this fragment), network.deeplab_v3 would receive a
# tensor from a different graph — confirm.
input_tensor = tf.image.decode_png(input_bytes, channels=3)
input_tensor = tf.cast(input_tensor, tf.float32)
# Model's inference function accepts a batch of images
# So expand the single tensor into a batch of 1
input_tensor = tf.expand_dims(input_tensor, 0)
# Resize the input tensor to a certain size
#input_tensor = tf.image.resize_bilinear(
#    input_tensor, [FLAGS.image_size, FLAGS.image_size], align_corners=False)
# Then, we feed the tensor to the model and save its output.
with graph.as_default():
    # Get model predictions
    logits_tf = network.deeplab_v3(input_tensor, args, is_training=False, reuse=False)
    # extract the segmentation classes, each value refers to a class
    predictions_tf = tf.argmax(logits_tf, axis=3)  # int64
with graph.as_default():
    # Cast the output to uint8 (class ids above 255 would not fit)
    output_tensor = tf.cast(predictions_tf, tf.uint8)
    # Remove the batch dimension
    output_tensor = tf.squeeze(output_tensor, 0)
    ## Stack the tensor to (?, ?, 3) for image encoding
    #output_tensor = tf.stack([output_tensor,output_tensor,output_tensor], 2)
    # Add a trailing channel axis, giving (height, width, 1); presumably for
    # single-channel image encoding downstream — verify.
    output_tensor = tf.expand_dims(output_tensor, -1)
def infer(self, image_dir, batch_size, ckpt, output_folder, organ):
    """Predict organ masks for a folder of CT slices and refine them.

    ``cvt_dcm_png(image_dir, "CT")`` is called first (presumably converting
    DICOM slices in ``image_dir`` to PNG — verify). Each prediction is then
    processed in these steps:

    1. Remove components smaller than 450 pixels.
    2. Upsample to 512x512, fill holes and (for the spleen) crop.
    3. Grow the region breadth-first. Growth is limited to ``visit_depth``
       pixels from the border and guided by pixel and superpixel intensity.
    4. Filter with a 5x5-mean intensity band.
    5. Open with a disk(5) footprint and remove components smaller than
       400 pixels.
    6. Write the mask to ``output_folder`` under the input image's file name.

    Args:
        image_dir: Directory where the images are located.
        batch_size: Batch size used to read the images. The post-processing
            only uses ``image[0]`` and indexes ``image_paths`` by step, so this
            presumably has to be 1 — TODO confirm.
        ckpt: Path of the checkpoint file to restore.
        output_folder: Folder where the predicted masks are saved.
        organ: Target organ name; for ``"spleen"`` mask columns 0-249 are
            cleared.
    """
    from collections import deque

    def _enqueue_neighbours(grid, queue, x, y, mark):
        """Queue the unseen (0-valued) 4-neighbours of (x, y), setting them to ``mark``."""
        rows, cols = grid.shape
        for nx, ny in ((x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)):
            if 0 <= nx < rows and 0 <= ny < cols and grid[nx, ny] == 0:
                grid[nx, ny] = mark
                queue.append((nx, ny))

    print(image_dir)
    cvt_dcm_png(image_dir, "CT")
    image_paths = [os.path.join(image_dir, name)
                   for name in os.listdir(image_dir)
                   if name.endswith(('.png', '.jpg'))]
    infer_data, infer_queue_init = utility.data_batch(
        image_paths, None, batch_size, shuffle=False)

    image_ph = None
    mask = None
    # NOTE(review): when self.restored is already True, image_ph and mask
    # stay None and the sess.run(mask, {image_ph: ...}) call below cannot
    # work; only the self.restored == False path is functional.
    if not self.restored:
        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
        if self.logits is None:
            self.logits = network.deeplab_v3(image_ph, is_training=True,
                                             reuse=tf.AUTO_REUSE)
        mask = tf.squeeze(tf.argmax(self.logits, axis=3))

    # Refinement parameters.
    segment = 1800     # requested number of SLIC superpixels
    compact = 100      # SLIC compactness
    std = 1.2          # region-growing band, in std of the predicted region
    std_rm = 1.85      # final 5x5-mean band, in std of the predicted region
    visit_depth = 150  # max BFS distance (pixels) from the region border

    saver = tf.train.Saver()
    with tf.Session() as sess:
        if not self.restored:
            saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        for step in range(len(image_paths) // batch_size):
            image = sess.run(infer_data)
            prediction = sess.run(mask, {image_ph: image})

            # Drop small components of the raw mask (plain `bool`: the
            # np.bool alias no longer exists in NumPy >= 1.24).
            prediction = remove_small_objects(prediction.astype(bool),
                                              min_size=450, connectivity=2)
            # Upsample to 512x512 and fill interior holes.
            prediction = binary_fill_holes(
                resize(prediction.astype(np.uint8) * 255, (512, 512))) * 255
            if organ == "spleen":
                prediction[:, 0:250] = 0
            out_path = os.path.join(output_folder,
                                    os.path.basename(image_paths[step]))
            print(out_path)

            image = resize((255 * image[0, :, :, 0]).astype(np.uint8),
                           (512, 512)) * 255
            mean5_image = convolve2d(image, np.ones((5, 5)) / 25,
                                     mode='same', boundary='symm')

            # Intensity statistics over the predicted region.
            region = prediction == 255
            segstd = np.std(image[region])
            segmean = np.mean(image[region])
            mean5_segstd = np.std(mean5_image[region])
            mean5_segmean = np.mean(mean5_image[region])
            print(segstd, segmean)

            image_slic = slic(image, n_segments=segment,
                              compactness=compact, sigma=1)

            # Breadth-first region growing.
            # Marks: 0 unseen, 1..visit_depth queued at that BFS depth,
            # 255 accepted, 1 also reused for rejected pixels.
            # Seed the queue with the region's border pixels at depth 1.
            visited = copy.deepcopy(prediction)
            queue_visit = deque()
            for x in range(1, 511):
                for y in range(1, 511):
                    if visited[x, y] == 255:
                        _enqueue_neighbours(visited, queue_visit, x, y, 1)

            slic_means = {}  # superpixel label -> mean intensity, filled lazily
            while queue_visit:
                x, y = queue_visit.popleft()
                depth = visited[x, y]
                # BFS pops pixels in order of depth, so everything still
                # queued is at least this far from the seed border.
                if depth >= visit_depth:
                    break
                label = image_slic[x, y]
                if label not in slic_means:
                    slic_means[label] = np.mean(image[image_slic == label])
                # Reject pixels outside mean +/- std*segstd, or whose
                # superpixel is darker on average than the predicted region.
                if (image[x, y] > segmean + std * segstd
                        or image[x, y] < segmean - std * segstd
                        or slic_means[label] < segmean):
                    visited[x, y] = 1
                else:
                    _enqueue_neighbours(visited, queue_visit, x, y, depth + 1)
                    visited[x, y] = 255

            # Keep accepted pixels only; rejected and still-queued markers
            # are nonzero and would otherwise count as foreground below.
            visited[visited != 255] = 0

            # Drop pixels whose 5x5 mean lies outside mean +/- std_rm * std.
            refined = copy.deepcopy(visited)
            refined[mean5_image > mean5_segmean + std_rm * mean5_segstd] = 0
            refined[mean5_image < mean5_segmean - std_rm * mean5_segstd] = 0
            # Footprint passed positionally (keyword was `selem`, now
            # `footprint`, in scikit-image).
            refined = binary_opening(refined, disk(5))
            # Boolean input so connected components are actually separated.
            refined = remove_small_objects(
                refined, min_size=400, connectivity=1).astype(np.uint8) * 255
            cv2.imwrite(out_path, refined)
# Feedable iterator: the string `handle` fed at run time selects which
# dataset's iterator supplies the next batch.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
batch_images_tf, batch_labels_tf, _ = iterator.get_next()

training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# Class ids 0 .. number_of_classes-1, followed by 255 (the padding/ignore
# label) as the final entry.
class_labels = [v for v in range((args.number_of_classes+1))]
class_labels[-1] = 255

# Scalar switch between the training-mode and inference-mode network.
is_training_tf = tf.placeholder(tf.bool, shape=[])

# This network variant returns two tensors: full-resolution logits and a
# lower-resolution ("small") set of logits. The inference branch passes
# reuse=True, presumably to share the training branch's variables.
logits_tf, small_logits_tf = tf.cond(is_training_tf,
    true_fn= lambda: network.deeplab_v3(batch_images_tf, args, is_training=True, reuse=False),
    false_fn=lambda: network.deeplab_v3(batch_images_tf, args, is_training=False, reuse=True))

# Spatial size (axes 1-2) of the small logits; the labels are resized to it.
small_logits_size_tf = tf.shape(small_logits_tf)[1:3]

# Add a channel axis so the labels can be resized as an image, resize with
# nearest neighbour (which never mixes label ids), then drop the axis again.
# The op name suggests an 8x downsampling factor — verify.
batch_labels_tf_dims = tf.expand_dims(batch_labels_tf, axis=3)
small_batch_labels_tf_dims = tf.image.resize_nearest_neighbor(batch_labels_tf_dims, small_logits_size_tf, name='label_downsample_x8')
small_batch_labels_tf = tf.squeeze(small_batch_labels_tf_dims,axis=3)

# Valid (non-ignored) pixels at full resolution ...
valid_labels_batch_tf, valid_logits_batch_tf = training.get_valid_logits_and_labels(
    annotation_batch_tensor=batch_labels_tf,
    logits_batch_tensor=logits_tf,
    class_labels=class_labels)

# ... and at the small-logits resolution (call continues beyond this view).
valid_small_labels_batch_tf, valid_small_logits_batch_tf = training.get_valid_logits_and_labels(
    annotation_batch_tensor=small_batch_labels_tf,
    logits_batch_tensor=small_logits_tf,