def body2(j, encoded_val_keys, encoded_val_queries, encoded_val_values): """Loop body.""" temp = eval_batch[j * FLAGS.batch_size:(j + 1) * FLAGS.batch_size, :, :, :] temp_keys, temp_queries, temp_values = model.cnn_encoder( temp, reuse=True, is_training=False) encoded_val_keys = tf.concat([encoded_val_keys, temp_keys], 0) encoded_val_queries = tf.concat([encoded_val_queries, temp_queries], 0) encoded_val_values = tf.concat([encoded_val_values, temp_values], 0) return j + 1, encoded_val_keys, encoded_val_queries, encoded_val_values
def body(i, encoded_cand_keys_val, encoded_cand_queries_val, encoded_cand_values_val): """Loop body.""" temp = image_eval_cand[i * FLAGS.example_cand_size:(i + 1) * FLAGS.example_cand_size, :, :, :] temp_keys, temp_queries, temp_values = model.cnn_encoder( temp, reuse=True, is_training=False) encoded_cand_keys_val = tf.concat([encoded_cand_keys_val, temp_keys], 0) encoded_cand_queries_val = tf.concat( [encoded_cand_queries_val, temp_queries], 0) encoded_cand_values_val = tf.concat([encoded_cand_values_val, temp_values], 0) return i+1, encoded_cand_keys_val, encoded_cand_queries_val, \ encoded_cand_values_val
def inference(input_image, m_encoded_test_cand_keys, m_encoded_test_cand_values, m_label_test_cand): """Constructs inference graph.""" processed_input_image = input_data.parse_function_test(input_image) _, encoded_query, _ = model.cnn_encoder( processed_input_image, reuse=False, is_training=False) weighted_encoded_test, weight_coefs_test = model.relational_attention( encoded_query, tf.constant(m_encoded_test_cand_keys), tf.constant(m_encoded_test_cand_values), reuse=False) _, prediction_weighted_test = model.classify( weighted_encoded_test, reuse=False) predicted_class = tf.argmax(prediction_weighted_test, axis=1) expl_per_class = tf.py_func( utils.class_explainability, (tf.constant(m_label_test_cand), weight_coefs_test), tf.float32) confidence = tf.reduce_max(expl_per_class, axis=1) return predicted_class, confidence, weight_coefs_test
def main(unused_argv): """Main function.""" # Load training and eval data - this portion can be modified if the data is # imported from other sources. (m_train_data, m_train_labels), (m_eval_data, m_eval_labels) = \ tf.keras.datasets.fashion_mnist.load_data() train_dataset = tf.data.Dataset.from_tensor_slices( (m_train_data, m_train_labels)) eval_dataset = tf.data.Dataset.from_tensor_slices( (m_eval_data, m_eval_labels)) train_dataset = train_dataset.map(input_data.parse_function_train) eval_dataset = eval_dataset.map(input_data.parse_function_eval) eval_batch_size = int( math.floor(len(m_eval_data) / FLAGS.batch_size) * FLAGS.batch_size) train_batch = train_dataset.repeat().batch(FLAGS.batch_size) train_cand = train_dataset.repeat().batch(FLAGS.example_cand_size) eval_cand = train_dataset.repeat().batch(FLAGS.eval_cand_size) eval_batch = eval_dataset.repeat().batch(eval_batch_size) iter_train = train_batch.make_initializable_iterator() iter_train_cand = train_cand.make_initializable_iterator() iter_eval_cand = eval_cand.make_initializable_iterator() iter_eval = eval_batch.make_initializable_iterator() image_batch, _, label_batch = iter_train.get_next() image_train_cand, _, _ = iter_train_cand.get_next() image_eval_cand, orig_image_eval_cand, label_eval_cand = iter_eval_cand.get_next( ) eval_batch, orig_eval_batch, eval_labels = iter_eval.get_next() # Model and loss definitions _, encoded_batch_queries, encoded_batch_values = model.cnn_encoder( image_batch, reuse=False, is_training=True) encoded_cand_keys, _, encoded_cand_values = model.cnn_encoder( image_train_cand, reuse=True, is_training=True) weighted_encoded_batch, weight_coefs_batch = model.relational_attention( encoded_batch_queries, encoded_cand_keys, encoded_cand_values, normalization=FLAGS.normalization) tf.summary.scalar( "Average max. coef. train", tf.reduce_mean(tf.reduce_max(weight_coefs_batch, axis=1))) # Sparsity regularization entropy_weights = tf.reduce_sum( -weight_coefs_batch * tf.log(FLAGS.epsilon_sparsity + weight_coefs_batch), axis=1) sparsity_loss = tf.reduce_mean(entropy_weights) - tf.log( FLAGS.epsilon_sparsity + tf.constant(FLAGS.example_cand_size, dtype=tf.float32)) tf.summary.scalar("Sparsity entropy loss", sparsity_loss) # Intermediate loss joint_encoded_batch = (1 - FLAGS.alpha_intermediate) * encoded_batch_values \ + FLAGS.alpha_intermediate * weighted_encoded_batch logits_joint_batch, _ = model.classify(joint_encoded_batch, reuse=False) softmax_joint_op = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits_joint_batch, labels=label_batch)) # Self loss logits_orig_batch, _ = model.classify(encoded_batch_values, reuse=True) softmax_orig_key_op = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits_orig_batch, labels=label_batch)) # Prototype combination loss logits_weighted_batch, _ = model.classify(weighted_encoded_batch, reuse=True) softmax_weighted_op = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits_weighted_batch, labels=label_batch)) train_loss_op = softmax_orig_key_op + softmax_weighted_op + \ softmax_joint_op + FLAGS.sparsity_weight * sparsity_loss tf.summary.scalar("Total loss", train_loss_op) global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.exponential_decay(FLAGS.init_learning_rate, global_step=global_step, decay_steps=FLAGS.decay_every, decay_rate=FLAGS.decay_rate) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) tf.summary.scalar("Learning rate", learning_rate) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): gvs = optimizer.compute_gradients(train_loss_op) capped_gvs = [(tf.clip_by_value(grad, -FLAGS.gradient_thresh, FLAGS.gradient_thresh), var) for grad, var in gvs] train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step) # Evaluate model # Process sequentially to avoid out-of-memory. i = tf.constant(0) encoded_cand_keys_val = tf.zeros([0, FLAGS.attention_dim]) encoded_cand_queries_val = tf.zeros([0, FLAGS.attention_dim]) encoded_cand_values_val = tf.zeros([0, FLAGS.val_dim]) def cond(i, unused_l1, unused_l2, unused_l3): return i < int( math.ceil(FLAGS.eval_cand_size / FLAGS.example_cand_size)) def body(i, encoded_cand_keys_val, encoded_cand_queries_val, encoded_cand_values_val): """Loop body.""" temp = image_eval_cand[i * FLAGS.example_cand_size:(i + 1) * FLAGS.example_cand_size, :, :, :] temp_keys, temp_queries, temp_values = model.cnn_encoder( temp, reuse=True, is_training=False) encoded_cand_keys_val = tf.concat([encoded_cand_keys_val, temp_keys], 0) encoded_cand_queries_val = tf.concat( [encoded_cand_queries_val, temp_queries], 0) encoded_cand_values_val = tf.concat( [encoded_cand_values_val, temp_values], 0) return i+1, encoded_cand_keys_val, encoded_cand_queries_val, \ encoded_cand_values_val _, encoded_cand_keys_val, encoded_cand_queries_val, \ encoded_cand_values_val, = tf.while_loop( cond, body, [i, encoded_cand_keys_val, encoded_cand_queries_val, encoded_cand_values_val], shape_invariants=[ i.get_shape(), tf.TensorShape([None, FLAGS.attention_dim]), tf.TensorShape([None, FLAGS.attention_dim]), tf.TensorShape([None, FLAGS.val_dim])]) j = tf.constant(0) encoded_val_keys = tf.zeros([0, FLAGS.attention_dim]) encoded_val_queries = tf.zeros([0, FLAGS.attention_dim]) encoded_val_values = tf.zeros([0, FLAGS.val_dim]) def cond2(j, unused_j1, unused_j2, unused_j3): return j < int(math.ceil(eval_batch_size / FLAGS.batch_size)) def body2(j, encoded_val_keys, encoded_val_queries, encoded_val_values): """Loop body.""" temp = eval_batch[j * FLAGS.batch_size:(j + 1) * FLAGS.batch_size, :, :, :] temp_keys, temp_queries, temp_values = model.cnn_encoder( temp, reuse=True, is_training=False) encoded_val_keys = tf.concat([encoded_val_keys, temp_keys], 0) encoded_val_queries = tf.concat([encoded_val_queries, temp_queries], 0) encoded_val_values = tf.concat([encoded_val_values, temp_values], 0) return j + 1, encoded_val_keys, encoded_val_queries, encoded_val_values _, encoded_val_keys, encoded_val_queries, \ encoded_val_values = tf.while_loop( cond2, body2, [ j, encoded_val_keys, encoded_val_queries, encoded_val_values], shape_invariants=[ j.get_shape(), tf.TensorShape([None, FLAGS.attention_dim]), tf.TensorShape([None, FLAGS.attention_dim]), tf.TensorShape([None, FLAGS.val_dim])]) weighted_encoded_val, weight_coefs_val = model.relational_attention( encoded_val_queries, encoded_cand_keys_val, encoded_cand_values_val, normalization=FLAGS.normalization) # Coefficient distribution tf.summary.scalar("Average max. coefficient val", tf.reduce_mean(tf.reduce_max(weight_coefs_val, axis=1))) # Analysis of median number of prototypes above a certain # confidence threshold. sorted_weights = tf.contrib.framework.sort(weight_coefs_val, direction="DESCENDING") cum_sorted_weights = tf.cumsum(sorted_weights, axis=1) for threshold in [0.5, 0.9, 0.95]: num_examples_thresh = tf.shape(sorted_weights)[1] + 1 - tf.reduce_sum( tf.cast(cum_sorted_weights > threshold, tf.int32), axis=1) tf.summary.histogram( "Number of samples for explainability above " + str(threshold), num_examples_thresh) tf.summary.scalar( "Median number of samples for explainability above " + str(threshold), tf.contrib.distributions.percentile(num_examples_thresh, q=50)) expl_per_class = tf.py_func(utils.class_explainability, (label_eval_cand, weight_coefs_val), tf.float32) max_expl = tf.reduce_max(expl_per_class, axis=1) tf.summary.histogram("Maximum per-class explainability", max_expl) _, prediction_val = model.classify(encoded_val_values, reuse=True) _, prediction_weighted_val = model.classify(weighted_encoded_val, reuse=True) val_eq_op = tf.equal(tf.cast(tf.argmax(prediction_val, 1), dtype=tf.int32), eval_labels) val_acc_op = tf.reduce_mean(tf.cast(val_eq_op, dtype=tf.float32)) tf.summary.scalar("Val accuracy input query", val_acc_op) val_weighted_eq_op = tf.equal( tf.cast(tf.argmax(prediction_weighted_val, 1), dtype=tf.int32), eval_labels) val_weighted_acc_op = tf.reduce_mean( tf.cast(val_weighted_eq_op, dtype=tf.float32)) tf.summary.scalar("Val accuracy weighted prototypes", val_weighted_acc_op) conf_wrong = tf.reduce_mean( (1 - tf.cast(val_weighted_eq_op, tf.float32)) * max_expl) tf.summary.scalar("Val average confidence of wrong decisions", conf_wrong) conf_right = tf.reduce_mean( tf.cast(val_weighted_eq_op, tf.float32) * max_expl) tf.summary.scalar("Val average confidence of right decisions", conf_right) # Confidence-controlled prediction for ti in [0.5, 0.8, 0.9, 0.95, 0.99, 0.999]: mask = tf.cast(tf.greater(max_expl, ti), tf.float32) acc_tot = tf.reduce_sum(tf.cast(val_weighted_eq_op, tf.float32) * mask) conf_tot = tf.reduce_sum(mask) tf.summary.scalar("Val accurate ratio for confidence above " + str(ti), acc_tot / conf_tot) tf.summary.scalar("Val total ratio for confidence above " + str(ti), conf_tot / eval_batch_size) # Visualization of example images and corresponding prototypes for image_ind in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: tf.summary.image( "Input image " + str(image_ind), tf.expand_dims(orig_eval_batch[image_ind, :, :, :], 0)) mask = tf.greater(weight_coefs_val[image_ind, :], 0.05) mask = tf.squeeze(mask) mask.set_shape([None]) relational_attention_images = tf.boolean_mask(orig_image_eval_cand, mask, axis=0) relational_attention_weight_coefs = tf.boolean_mask(tf.squeeze( weight_coefs_val[image_ind, :]), mask, axis=0) annotated_images = utils.tf_put_text( relational_attention_images, relational_attention_weight_coefs) tf.summary.image("Prototype images for image " + str(image_ind), annotated_images) # Training setup init = (tf.global_variables_initializer(), tf.local_variables_initializer()) saver_all = tf.train.Saver() summaries = tf.summary.merge_all() with tf.Session() as sess: summary_writer = tf.summary.FileWriter("./tflog/" + model_name, sess.graph) sess.run(init) sess.run(iter_train.initializer) sess.run(iter_train_cand.initializer) sess.run(iter_eval_cand.initializer) sess.run(iter_eval.initializer) for step in range(1, FLAGS.num_steps): if step % FLAGS.display_step == 0: _, train_loss = sess.run([train_op, train_loss_op]) print("Step " + str(step) + " , Training loss = " + "{:.4f}".format(train_loss)) else: sess.run(train_op) if step % FLAGS.val_step == 0: val_acc, merged_summary = sess.run( [val_weighted_acc_op, summaries]) print("Step " + str(step) + " , Val Accuracy = " + "{:.4f}".format(val_acc)) summary_writer.add_summary(merged_summary, step) if step % FLAGS.save_step == 0: saver_all.save(sess, checkpoint_name)