def model_fn(features, labels, mode, params):
  """Defines the model, compatible with tf.estimator."""
  del labels, params
  if mode == tf.estimator.ModeKeys.TRAIN:
    _build_deeplab(features, model.get_output_to_num_classes(FLAGS),
                   model_input.dataset_descriptors[FLAGS.dataset].ignore_label)

    # Print the objective loss and regularization loss independently to
    # track the NaN loss issue.
    objective_losses = tf.losses.get_losses()
    objective_losses = tf.Print(
        objective_losses, [objective_losses],
        message='Objective Losses: ',
        summarize=100)
    objective_loss = tf.reduce_sum(objective_losses)
    tf.summary.scalar('objective_loss', objective_loss)

    reg_losses = tf.losses.get_regularization_losses()
    reg_losses = tf.Print(
        reg_losses, [reg_losses], message='Reg Losses: ', summarize=100)
    reg_loss = tf.reduce_sum(reg_losses)
    tf.summary.scalar('regularization_loss', reg_loss)

    loss = objective_loss + reg_loss

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss)
    grad_updates = optimizer.apply_gradients(grads_and_vars,
                                             tf.train.get_global_step())
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_op = tf.identity(loss, name='train_op')

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
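
# --- Usage sketch (added for illustration, not part of the original file):
# how the model_fn above could be wired into a tf.estimator.Estimator.
# `FLAGS.train_logdir` is a hypothetical flag; `model_input.get_input_fn`
# and `FLAGS.training_number_of_steps` are the ones already used elsewhere
# in this codebase.
def _train_with_estimator_sketch():
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.train_logdir)  # hypothetical flag, for illustration
  estimator.train(
      input_fn=model_input.get_input_fn(FLAGS),
      max_steps=FLAGS.training_number_of_steps)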
def main(unused_argv):
  FLAGS.comb_dropout_keep_prob = 1.0
  FLAGS.image_keep_prob = 1.0
  FLAGS.elements_keep_prob = 1.0

  # Get dataset-dependent information.
  tf.gfile.MakeDirs(FLAGS.eval_logdir)
  tf.logging.info('Evaluating on %s set', FLAGS.split)

  with tf.Graph().as_default():
    samples = model_input.get_input_fn(FLAGS)()

    # Get model segmentation predictions.
    num_classes = model_input.dataset_descriptors[FLAGS.dataset].num_classes
    output_to_num_classes = model.get_output_to_num_classes(FLAGS)

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions, probs = model.predict_labels(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=output_to_num_classes,
          image_pyramid=FLAGS.image_pyramid,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions, probs = model.predict_labels_multi_scale(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=output_to_num_classes,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)

    metric_map = {}
    for output in output_to_num_classes:
      output_predictions = predictions[output]
      output_probs = probs[output]
      if output == 'segment':
        output_predictions = tf.expand_dims(output_predictions, 3)
        if num_classes == 2:
          labels = samples['label']
          iou, weights = model.foreground_iou(labels, output_predictions,
                                              FLAGS)
          soft_iou, _ = model.foreground_iou(labels,
                                             output_probs[:, :, :, 1:2],
                                             FLAGS)

          metric_map['mIOU'] = tf.metrics.mean(iou)
          metric_map['soft_mIOU'] = tf.metrics.mean(soft_iou)

          high_prob_overlaps = calc_high_prob_overlaps(labels, output_probs,
                                                       weights)
          metric_map['highestOverlaps'] = tf.metrics.mean(high_prob_overlaps)

          output_probs *= weights
        else:
          output_predictions = tf.reshape(output_predictions, shape=[-1])
          labels = tf.reshape(samples['label'], shape=[-1])
          weights = tf.to_float(
              tf.not_equal(
                  labels,
                  model_input.dataset_descriptors[FLAGS.dataset].ignore_label))

          # Set ignore_label regions to label 0, because metrics.mean_iou
          # requires the label range to be [0, dataset.num_classes).
          # Note the ignore_label regions are not evaluated, since the
          # corresponding regions have weights=0.
          labels = tf.where(
              tf.equal(labels,
                       model_input.dataset_descriptors[
                           FLAGS.dataset].ignore_label),
              tf.zeros_like(labels), labels)

          predictions_tag = 'mIOU'
          for eval_scale in FLAGS.eval_scales:
            predictions_tag += '_' + str(eval_scale)
          if FLAGS.add_flipped_images:
            predictions_tag += '_flipped'

          # Define the evaluation metric.
          metric_map[predictions_tag] = slim.metrics.mean_iou(
              output_predictions, labels, num_classes, weights=weights)

          def label_summary(labels, weights, name):
            tf.summary.image(
                name,
                tf.reshape(
                    tf.cast(
                        tf.to_float(labels * 255) / tf.to_float(num_classes),
                        tf.uint8) * tf.cast(weights, tf.uint8),
                    [-1, FLAGS.image_size, FLAGS.image_size, 1]),
                8)

          label_summary(labels, weights, 'label')
          label_summary(output_predictions, weights, 'output_predictions')
          tf.summary.image('logits',
                           tf.expand_dims(output_probs[:, :, :, 1], 3))
      elif output == 'regression':
        labels = samples['label']
        ignore_mask = model.get_ignore_mask(labels, FLAGS)
        accurate = calc_accuracy_in_box(labels, output_probs, ignore_mask)
        metric_map['inBoxAccuracy'] = tf.metrics.mean(accurate)

    tf.summary.image('image', samples['image'], 8)

    metrics_to_values, metrics_to_updates = (
        slim.metrics.aggregate_metric_map(metric_map))

    for metric_name, metric_value in metrics_to_values.items():
      metric_value = tf.Print(metric_value, [metric_value], metric_name)
      tf.summary.scalar(metric_name, metric_value)

    num_batches = int(math.ceil(FLAGS.num_samples / float(FLAGS.batch_size)))

    tf.logging.info('Eval num images %d', FLAGS.num_samples)
    tf.logging.info('Eval batch size %d and num batch %d', FLAGS.batch_size,
                    num_batches)

    slim.evaluation.evaluation_loop(
        master='',
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_logdir,
        num_evals=num_batches,
        eval_op=list(metrics_to_updates.values()),
        summary_op=tf.summary.merge_all(),
        max_number_of_evaluations=None,
        eval_interval_secs=FLAGS.eval_interval_secs)
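
# --- Self-contained sketch (added for illustration, not part of the
# original): the ignore-label masking pattern used by the eval metric above.
# Pixels equal to `ignore_label` get weight 0 and are remapped to class 0, so
# the mean-IoU metric only ever sees labels in [0, num_classes) while the
# ignored regions contribute nothing to the result.
def _masked_mean_iou_sketch(labels, predictions, num_classes, ignore_label):
  labels = tf.reshape(labels, shape=[-1])
  predictions = tf.reshape(predictions, shape=[-1])
  weights = tf.to_float(tf.not_equal(labels, ignore_label))
  labels = tf.where(
      tf.equal(labels, ignore_label), tf.zeros_like(labels), labels)
  # Returns the usual (value, update_op) streaming-metric pair.
  return tf.metrics.mean_iou(labels, predictions, num_classes,
                             weights=weights)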
def main(unused_argv):
  # Get dataset-dependent information.
  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(FLAGS.vis_logdir,
                              _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)

  num_vis_examples = FLAGS.num_vis_examples

  print('Visualizing on set', FLAGS.split)

  g = tf.Graph()
  with g.as_default():
    samples = model_input.get_input_fn(FLAGS)()
    outputs_to_num_classes = model.get_output_to_num_classes(FLAGS)

    # Get model segmentation predictions.
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions, probs = model.predict_labels(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          image_pyramid=FLAGS.image_pyramid,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions, probs = model.predict_labels_multi_scale(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)

    if FLAGS.output_mode == 'segment':
      predictions = tf.squeeze(
          tf.cast(predictions[FLAGS.output_mode], tf.int32))
      probs = probs[FLAGS.output_mode]

      labels = tf.squeeze(tf.cast(samples['label'], tf.int32))
      weights = tf.cast(
          tf.not_equal(
              labels,
              model_input.dataset_descriptors[FLAGS.dataset].ignore_label),
          tf.int32)

      labels *= weights
      predictions *= weights

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(contrib_slim.get_variables_to_restore())
    sv = tf.train.Supervisor(
        graph=g,
        logdir=FLAGS.vis_logdir,
        init_op=tf.global_variables_initializer(),
        summary_op=None,
        summary_writer=None,
        global_step=None,
        saver=saver)

    num_batches = int(math.ceil(num_vis_examples / float(FLAGS.batch_size)))
    last_checkpoint = None

    # Infinite loop to visualize the results when a new checkpoint is created.
    while True:
      last_checkpoint = contrib_slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      print('Starting visualization at ' +
            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      print('Visualizing with model', last_checkpoint)

      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        # sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        image_id_offset = 0
        refs = []
        for batch in range(num_batches):
          print('Visualizing batch', batch + 1, 'of', num_batches)
          refs.extend(
              _process_batch(
                  sess=sess,
                  samples=samples,
                  semantic_predictions=predictions,
                  labels=labels,
                  image_id_offset=image_id_offset,
                  save_dir=save_dir))
          image_id_offset += FLAGS.batch_size

      print('Finished visualization at ' +
            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
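
# --- One-shot variant (a sketch under assumptions, not in the original):
# visualize the newest checkpoint once instead of polling forever. It reuses
# the Supervisor `sv`, `samples`, `predictions`, `labels`, `save_dir`, and
# `num_batches` built in main() above; `_process_batch` is the same helper
# the loop calls.
def _visualize_latest_once(sv, samples, predictions, labels, save_dir,
                           num_batches):
  last_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  with sv.managed_session(FLAGS.master,
                          start_standard_services=False) as sess:
    sv.saver.restore(sess, last_checkpoint)
    image_id_offset = 0
    for _ in range(num_batches):
      _process_batch(
          sess=sess,
          samples=samples,
          semantic_predictions=predictions,
          labels=labels,
          image_id_offset=image_id_offset,
          save_dir=save_dir)
      image_id_offset += FLAGS.batch_size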