def produce_saliency_map(self, data_path, writer):
        """Writes one serialized tf.Example with a saliency map per dataset image.

        Builds a fresh graph holding the dataset input pipeline, per-channel
        normalization with MEAN_RGB / STDDEV_RGB and a ResNet-50 restored from
        `self._ckpt_directory`. It then pulls images until the input pipeline
        raises `tf.errors.OutOfRangeError`, or until two examples have been
        written when `FLAGS.test_small_sample` is set.

        For every 3-channel image, a saliency map is computed for the class the
        network predicts. The map comes from the 'gradient', 'guided_backprop'
        and 'integrated_gradients' ops together with `self._saliency_method`,
        or from a Sobel filter along axis 0 when that method is 'SOBEL'.
        The map is flattened to float32 and written together with the raw
        image and the label. Images without 3 channels are skipped.

        Args:
          data_path: Path handed to `DataIterator` to locate the input data.
          writer: Object with `write()` and `close()` (presumably a
            TFRecord writer). It receives serialized examples and is always
            closed before this method returns.

        Side effects:
          Sets `self._dataset`, `self._graph`, `self._sess`,
          `self._neuron_selector`, `self._gradient_placeholder`,
          `self._back_prop_placeholder`,
          `self._integrated_gradient_placeholder` and `self._coord`.
          `self._sess` is left open.
        """

        self._dataset = DataIterator(data_path,
                                     self._dataset_name,
                                     preprocessing=False,
                                     test_small_sample=FLAGS.test_small_sample)

        self._graph = tf.Graph()
        with self._graph.as_default():
            image_raw, image_processed, label = self._dataset.input_fn()

            # Normalize each RGB channel; the [1, 1, 3] constants broadcast
            # over the leading dimensions.
            image_processed -= tf.constant(MEAN_RGB,
                                           shape=[1, 1, 3],
                                           dtype=image_processed.dtype)
            image_processed /= tf.constant(STDDEV_RGB,
                                           shape=[1, 1, 3],
                                           dtype=image_processed.dtype)

            network = resnet_model.resnet_50(
                num_classes=self._num_label_classes,
                data_format='channels_last',
            )
            logits = network(inputs=image_processed, is_training=False)
            prediction = tf.cast(tf.argmax(logits, axis=1), tf.int32)

            # The saliency target is a single logit of the first example in
            # the batch. Which logit is chosen at run time through this
            # placeholder.
            self._neuron_selector = tf.placeholder(tf.int32)
            y = logits[0][self._neuron_selector]

            self._sess = tf.Session(graph=self._graph)
            saver = tf.train.Saver()
            saver.restore(self._sess, self._ckpt_directory)

            self._gradient_placeholder = get_saliency_image(
                self._graph, self._sess, y, image_processed, 'gradient')
            self._back_prop_placeholder = get_saliency_image(
                self._graph, self._sess, y, image_processed, 'guided_backprop')
            self._integrated_gradient_placeholder = get_saliency_image(
                self._graph, self._sess, y, image_processed,
                'integrated_gradients')

            baseline = SALIENCY_BASELINE['resnet_50']

            self._coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self._sess,
                                                   coord=self._coord)

            example_count = 0
            try:
                while True:
                    img_out, raw_img_out, label_out, prediction_out = self._sess.run(
                        [image_processed, image_raw, label, prediction])
                    if img_out.shape[3] != 3:
                        # Only 3-channel images get a saliency map. Writing
                        # anyway would emit the previous image's map for this
                        # image (or fail with NameError on the first image).
                        continue

                    # Drop size-1 dimensions (presumably the batch of one).
                    img_out = np.squeeze(img_out)
                    if self._saliency_method == 'SOBEL':
                        saliency_map = ndimage.sobel(img_out, axis=0)
                    else:
                        # Explain the class the network predicted for the
                        # first example.
                        feed_dict = {self._neuron_selector: prediction_out[0]}
                        saliency_map = generate_saliency_image(
                            self._saliency_method, img_out, feed_dict,
                            self._gradient_placeholder,
                            self._back_prop_placeholder,
                            self._integrated_gradient_placeholder,
                            baseline)

                    # Serialize the map as a flat float32 vector.
                    saliency_map = np.reshape(saliency_map.astype(np.float32),
                                              [-1])
                    example = image_to_tfexample(raw_image=raw_img_out[0],
                                                 maps=saliency_map,
                                                 label=label_out)
                    writer.write(example.SerializeToString())
                    example_count += 1

                    # Small-sample test mode stops after two examples.
                    if FLAGS.test_small_sample and example_count == 2:
                        break

            except tf.errors.OutOfRangeError:
                # The input pipeline is exhausted.
                print('Finished number of images:', example_count)
            finally:
                self._coord.request_stop()
                self._coord.join(threads)
                writer.close()
Exemplo n.º 2
0
def resnet_model_fn(features, labels, mode, params):
    """Sets up training, evaluation and prediction of a ResNet-50 model.

    Input images are normalized per channel with `params['mean_rgb']` and
    `params['stddev_rgb']` before being fed to the network.

    Args:
      features: A float32 batch of images, or a dict holding that batch under
        the key 'feature'.
      labels: An int32 batch of labels. It is unused, and may be None, in
        PREDICT mode.
      mode: A `tf.estimator.ModeKeys` value: TRAIN, EVAL or PREDICT.
      params: Dictionary of parameters passed to the model.
        Always read: 'mean_rgb', 'stddev_rgb', 'num_label_classes',
        'data_format'.
        Read in TRAIN and EVAL: 'weight_decay'.
        Read in TRAIN only: 'train_batch_size', 'num_train_images',
        'base_learning_rate', 'lr_schedule', 'momentum', 'output_dir'.

    Returns:
      A `tf.estimator.EstimatorSpec` for `mode`.
      - TRAIN and EVAL: the loss is softmax cross-entropy with label
        smoothing 0.1, plus L2 weight decay over every trainable variable
        whose name does not contain 'batch_normalization'.
      - EVAL: also reports top-1 and top-5 accuracy.
      - PREDICT: returns the predicted class and the softmax probabilities.
    """
    if isinstance(features, dict):
        features = features['feature']

    # Normalize each RGB channel; the [1, 1, 3] constants broadcast over the
    # leading dimensions.
    features -= tf.constant(params['mean_rgb'], shape=[1, 1, 3],
                            dtype=features.dtype)
    features /= tf.constant(params['stddev_rgb'], shape=[1, 1, 3],
                            dtype=features.dtype)

    num_label_classes = params['num_label_classes']
    network = resnet_model.resnet_50(num_classes=num_label_classes,
                                     data_format=params['data_format'])
    logits = network(inputs=features,
                     is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Labels are not available in PREDICT mode, so no loss can be built;
        # tf.one_hot(None, ...) would fail below.
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    weight_decay = params['weight_decay']
    one_hot_labels = tf.one_hot(labels, num_label_classes)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels, label_smoothing=0.1)
    # Batch-normalization variables are excluded from weight decay.
    loss = cross_entropy + weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    train_op = None
    eval_metrics = {}
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()

        train_batch_size = params['train_batch_size']
        steps_per_epoch = params['num_train_images'] / train_batch_size
        current_epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
        learning_rate = compute_lr(current_epoch, params['base_learning_rate'],
                                   train_batch_size, params['lr_schedule'])
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=params['momentum'],
                                               use_nesterov=True)

        # Run everything in the UPDATE_OPS collection (typically the
        # batch-norm moving-average updates) with each training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops), tf.name_scope('train'):
            train_op = optimizer.minimize(loss, global_step)

        with tf2.summary.create_file_writer(params['output_dir']).as_default():
            with tf2.summary.record_if(True):
                tf2.summary.scalar('loss', loss, step=global_step)
                tf2.summary.scalar('learning_rate',
                                   learning_rate,
                                   step=global_step)
                tf2.summary.scalar('current_epoch',
                                   current_epoch,
                                   step=global_step)
                tf2.summary.scalar('steps_per_epoch',
                                   steps_per_epoch,
                                   step=global_step)
                tf2.summary.scalar('weight_decay',
                                   weight_decay,
                                   step=global_step)

            # NOTE(review): the ops returned here are discarded, so nothing
            # in this function runs them. Confirm the summaries are actually
            # written; otherwise attach them to train_op or a training hook.
            tf.summary.all_v2_summary_ops()

    elif mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        eval_metrics['top_1_accuracy'] = tf.metrics.accuracy(
            labels, predictions)
        in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
        eval_metrics['top_5_accuracy'] = tf.metrics.mean(in_top_5)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metrics)