# Note: these functions are excerpted from a TF-Slim codebase; they assume
# module-level imports such as `tensorflow as tf`, `slim`, `math`, `glob`,
# `os`, `matplotlib`, and the project's own `imagenet_model`, `cifar_model`,
# `summary_utils`, `training_utils`, `utils`, and `imagenet_data_provider`.
def testFlopsVanilla(self):
  batch_size = 3
  height, width = 224, 224
  num_classes = 1001

  with self.test_session() as sess:
    images = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      _, end_points = imagenet_model.get_network(
          images, [101], num_classes, 'vanilla')
    flops = sess.run(end_points['flops'])

    # TF graph_metrics value: 15614055401 (0.1% difference)
    expected_flops = 15602814976
    self.assertAllEqual(flops, [expected_flops] * 3)
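
# Quick arithmetic check (illustrative, not part of the original test) that
# the "0.1% difference" claim in the comment above holds:
graph_metrics_flops = 15614055401
analytic_flops = 15602814976
rel_diff = abs(graph_metrics_flops - analytic_flops) / float(graph_metrics_flops)
assert rel_diff < 1e-3  # ~0.00072, i.e. about 0.07%.
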
def _runBatch(self, is_training, model_type, model=[2, 2, 2, 2]):
  batch_size = 2
  height, width = 128, 128
  num_classes = 10

  with self.test_session() as sess:
    images = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope(
        imagenet_model.resnet_arg_scope(is_training=is_training)):
      logits, end_points = imagenet_model.get_network(
          images,
          model,
          num_classes,
          model_type=model_type,
          base_channels=1)

    if model_type in ('act', 'act_early_stopping', 'sact'):
      metrics = summary_utils.act_metric_map(end_points, not is_training)
      metrics.update(
          summary_utils.flops_metric_map(end_points, not is_training))
    else:
      metrics = {}

    if is_training:
      labels = tf.random_uniform(
          (batch_size,), maxval=num_classes, dtype=tf.int32)
      one_hot_labels = slim.one_hot_encoding(labels, num_classes)
      tf.losses.softmax_cross_entropy(
          onehot_labels=one_hot_labels,
          logits=logits,
          label_smoothing=0.1,
          weights=1.0)
      if model_type in ('act', 'act_early_stopping', 'sact'):
        training_utils.add_all_ponder_costs(end_points, weights=1.0)
      total_loss = tf.losses.get_total_loss()
      optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
      train_op = slim.learning.create_train_op(total_loss, optimizer)
      sess.run(tf.global_variables_initializer())
      sess.run((train_op, metrics))
    else:
      sess.run([
          tf.local_variables_initializer(),
          tf.global_variables_initializer()
      ])
      logits_out, metrics_out = sess.run((logits, metrics))
      self.assertEqual(logits_out.shape, (batch_size, num_classes))
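
# Usage sketch (hypothetical test methods; the names are assumptions, but the
# call pattern follows the helper's signature above). The helper is
# parameterized so one body can exercise both the training and inference
# paths for each model type:
def testTrainSact(self):
  self._runBatch(is_training=True, model_type='sact')

def testEvalVanilla(self):
  self._runBatch(is_training=False, model_type='vanilla')
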
def main(_):
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  assert FLAGS.model is not None
  assert FLAGS.model_type in ('vanilla', 'act', 'act_early_stopping', 'sact')
  assert FLAGS.dataset in ('imagenet', 'cifar')

  batch_size = 1

  if FLAGS.dataset == 'imagenet':
    height, width = 224, 224
    num_classes = 1001
  elif FLAGS.dataset == 'cifar':
    height, width = 32, 32
    num_classes = 10

  images = tf.random_uniform((batch_size, height, width, 3))
  model = utils.split_and_int(FLAGS.model)

  # Define the model:
  if FLAGS.dataset == 'imagenet':
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      logits, end_points = imagenet_model.get_network(
          images, model, num_classes, model_type=FLAGS.model_type)
  elif FLAGS.dataset == 'cifar':
    with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
      logits, end_points = cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=FLAGS.model_type)

  tf_global_step = slim.get_or_create_global_step()

  checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir)
  assert checkpoint_path is not None

  saver = tf.train.Saver(write_version=2)
  with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    saver.save(sess, FLAGS.output_dir + '/model', global_step=tf_global_step)
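
# Example invocation (the script name and paths below are illustrative
# assumptions; the flags are the ones read above):
#
#   python checkpoint_resave.py \
#     --input_dir=/tmp/train_logs \
#     --output_dir=/tmp/resaved \
#     --model=101 \
#     --model_type=sact \
#     --dataset=imagenet
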
def testVisualizationBasic(self):
  batch_size = 5
  height, width = 128, 128
  num_classes = 10
  is_training = False
  num_images = 3
  border = 5

  with self.test_session() as sess:
    images = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope(
        imagenet_model.resnet_arg_scope(is_training=is_training)):
      logits, end_points = imagenet_model.get_network(
          images, [2, 2, 2, 2], num_classes,
          model_type='sact',
          base_channels=1)

    vis_ponder = summary_utils.sact_image_heatmap(
        end_points,
        'ponder_cost',
        num_images=num_images,
        alpha=0.75,
        border=border)
    vis_units = summary_utils.sact_image_heatmap(
        end_points,
        'num_units',
        num_images=num_images,
        alpha=0.75,
        border=border)

    sess.run(tf.global_variables_initializer())
    vis_ponder_out, vis_units_out = sess.run([vis_ponder, vis_units])

    self.assertEqual(vis_ponder_out.shape,
                     (num_images, height, width * 2 + border, 3))
    self.assertEqual(vis_units_out.shape,
                     (num_images, height, width * 2 + border, 3))
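
# Shape note (explanatory; assumes sact_image_heatmap lays the input image
# and its heatmap side by side with a `border`-pixel strip between them):
# for height = width = 128 and border = 5 the expected output width is
# 128 * 2 + 5 = 261, giving the (3, 128, 261, 3) shape asserted above.
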
def main(_):
  assert FLAGS.model_type in ('act', 'act_early_stopping', 'sact')

  g = tf.Graph()
  with g.as_default():
    data_tuple = imagenet_data_provider.provide_data(
        FLAGS.split_name,
        FLAGS.batch_size,
        dataset_dir=FLAGS.dataset_dir,
        is_training=False)
    images, labels, _, num_classes = data_tuple

    # Define the model:
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      model = utils.split_and_int(FLAGS.model)
      logits, end_points = imagenet_model.get_network(
          images, model, num_classes, model_type=FLAGS.model_type)

    summary_utils.export_to_h5(FLAGS.checkpoint_dir, FLAGS.export_path,
                               images, end_points, FLAGS.num_examples,
                               FLAGS.batch_size, FLAGS.model_type == 'sact')
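
# Example invocation (illustrative; the script name and values are
# assumptions, the flags are the ones referenced above):
#
#   python imagenet_export.py \
#     --model=101 \
#     --model_type=sact \
#     --split_name=validation \
#     --checkpoint_dir=/tmp/train_logs \
#     --export_path=/tmp/maps.h5 \
#     --batch_size=32 \
#     --num_examples=1000
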
def main(_):
  g = tf.Graph()
  with g.as_default():
    data_tuple = imagenet_data_provider.provide_data(
        FLAGS.split_name,
        FLAGS.batch_size,
        dataset_dir=FLAGS.dataset_dir,
        is_training=False,
        image_size=FLAGS.image_size)
    images, one_hot_labels, examples_per_epoch, num_classes = data_tuple

    # Define the model:
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      model = utils.split_and_int(FLAGS.model)
      logits, end_points = imagenet_model.get_network(
          images, model, num_classes, model_type=FLAGS.model_type)

    predictions = tf.argmax(end_points['predictions'], 1)

    # Define the metrics:
    labels = tf.argmax(one_hot_labels, 1)
    metric_map = {
        'eval/Accuracy':
            tf.contrib.metrics.streaming_accuracy(predictions, labels),
        'eval/Recall@5':
            tf.contrib.metrics.streaming_sparse_recall_at_k(
                end_points['predictions'], tf.expand_dims(labels, 1), 5),
    }
    metric_map.update(summary_utils.flops_metric_map(end_points, True))
    if FLAGS.model_type in ['act', 'act_early_stopping', 'sact']:
      metric_map.update(summary_utils.act_metric_map(end_points, True))

    names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map(
        metric_map)

    for name, value in names_to_values.items():
      summ = tf.summary.scalar(name, value, collections=[])
      summ = tf.Print(summ, [value], name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

    if FLAGS.model_type == 'sact':
      summary_utils.add_heatmaps_image_summary(end_points, border=10)

    # This ensures that we make a single pass over all of the data.
    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))

    if not FLAGS.evaluate_once:
      eval_function = slim.evaluation.evaluation_loop
      checkpoint_path = FLAGS.checkpoint_dir
      kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
    else:
      eval_function = slim.evaluation.evaluate_once
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
      assert checkpoint_path is not None
      kwargs = {}

    eval_function(
        FLAGS.master,
        checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        **kwargs)
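
# Worked example for the single-pass computation above (assumed values: the
# ImageNet validation split has 50,000 images, batch_size=128):
#
#   num_batches = ceil(50000 / 128.) = ceil(390.625) = 391
#
# Rounding up means the final partial batch is still evaluated, so no
# examples are silently dropped.
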
def main(_):
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  num_classes = 1001

  path = tf.placeholder(tf.string)
  contents = tf.read_file(path)
  image = tf.image.decode_jpeg(contents, channels=3)
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  images = tf.expand_dims(image, 0)
  images.set_shape([1, None, None, 3])

  if FLAGS.image_size:
    # Resize so that the longer side equals FLAGS.image_size, preserving the
    # aspect ratio.
    sh = tf.shape(image)
    height, width = tf.to_float(sh[0]), tf.to_float(sh[1])
    longer_size = tf.constant(FLAGS.image_size, dtype=tf.float32)
    new_size = tf.cond(
        height >= width,
        lambda: (longer_size, (width / height) * longer_size),
        lambda: ((height / width) * longer_size, longer_size))
    images_resized = tf.image.resize_images(
        images,
        size=tf.to_int32(tf.stack(new_size)),
        method=tf.image.ResizeMethod.BICUBIC)
  else:
    images_resized = images

  images_resized = preprocessing(images_resized)

  # Define the model:
  with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
    model = utils.split_and_int(FLAGS.model)
    logits, end_points = imagenet_model.get_network(
        images_resized, model, num_classes, model_type='sact')
    ponder_cost_map = summary_utils.sact_map(end_points, 'ponder_cost')

  checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  assert checkpoint_path is not None
  saver = tf.train.Saver()
  sess = tf.Session()
  saver.restore(sess, checkpoint_path)

  # Build the output ops once, outside the loop, so the graph does not grow
  # with every processed image.
  output_image = tf.squeeze(reverse_preprocessing(images_resized), 0)
  output_ponder = tf.squeeze(ponder_cost_map, [0, 3])

  for current_path in glob.glob(FLAGS.images_pattern):
    print('Processing {}'.format(current_path))

    image_resized_out, ponder_cost_map_out = sess.run(
        [output_image, output_ponder], feed_dict={path: current_path})

    basename = os.path.splitext(os.path.basename(current_path))[0]

    if FLAGS.image_size:
      matplotlib.image.imsave(
          os.path.join(FLAGS.output_dir, '{}_im.jpg'.format(basename)),
          image_resized_out)
    matplotlib.image.imsave(
        os.path.join(FLAGS.output_dir, '{}_ponder.jpg'.format(basename)),
        ponder_cost_map_out,
        cmap='viridis')

    min_ponder = ponder_cost_map_out.min()
    max_ponder = ponder_cost_map_out.max()
    print('Minimum/maximum ponder cost {:.2f}/{:.2f}'.format(
        min_ponder, max_ponder))

    # Save a standalone colorbar for the ponder cost heatmap.
    fig = plt.figure(figsize=(0.2, 2))
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    cb = matplotlib.colorbar.ColorbarBase(
        ax,
        cmap='viridis',
        norm=matplotlib.colors.Normalize(vmin=min_ponder, vmax=max_ponder))
    ax.tick_params(labelsize=12)
    filename = os.path.join(FLAGS.output_dir,
                            '{}_colorbar.pdf'.format(basename))
    plt.savefig(filename, bbox_inches='tight')
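
# Worked example of the aspect-preserving resize above (input size assumed):
# for a 480x640 (height x width) image and --image_size=512, width is the
# longer side, so
#
#   new_size = ((480 / 640) * 512, 512) = (384, 512)
#
# i.e. the longer side becomes exactly image_size and the aspect ratio is
# preserved.
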
def main(_):
  g = tf.Graph()
  with g.as_default():
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, merge_devices=True)):
      data_tuple = imagenet_data_provider.provide_data(
          FLAGS.split_name,
          FLAGS.batch_size,
          dataset_dir=FLAGS.dataset_dir,
          is_training=True,
          image_size=FLAGS.image_size)
      images, labels, examples_per_epoch, num_classes = data_tuple

      # Define the model:
      with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=True)):
        model = utils.split_and_int(FLAGS.model)
        logits, end_points = imagenet_model.get_network(
            images, model, num_classes, model_type=FLAGS.model_type)

      # Specify the loss function (labels are one-hot here):
      tf.losses.softmax_cross_entropy(
          onehot_labels=labels,
          logits=logits,
          label_smoothing=0.1,
          weights=1.0)
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
      total_loss = tf.losses.get_total_loss()

      # Configure the learning rate using an exponential decay.
      decay_steps = int(examples_per_epoch / FLAGS.batch_size *
                        FLAGS.num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

      init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path)

      train_tensor = slim.learning.create_train_op(
          total_loss,
          optimizer=opt,
          update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

      # Summaries:
      tf.summary.scalar('losses/Total Loss', total_loss)
      tf.summary.scalar('training/Learning Rate', learning_rate)

      metric_map = {}  # summary_utils.flops_metric_map(end_points, False)
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        metric_map.update(summary_utils.act_metric_map(end_points, False))
      for name, value in metric_map.items():
        tf.summary.scalar(name, value)

      if FLAGS.model_type == 'sact':
        summary_utils.add_heatmaps_image_summary(end_points, border=10)

      startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

      slim.learning.train(
          train_tensor,
          init_fn=init_fn,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          startup_delay_steps=startup_delay_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
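
# Learning-rate schedule note: with staircase=True, the effective rate is
#
#   lr(step) = learning_rate * learning_rate_decay_factor ** (step // decay_steps)
#
# Worked example (assumed values): examples_per_epoch = 1281167 (the ImageNet
# training split), batch_size = 256 and num_epochs_per_decay = 30 give
# decay_steps = int(1281167 / 256 * 30) ≈ 150,000, i.e. one decay roughly
# every 30 epochs.
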