def testFlopsVanilla(self):
        """Checks the 'flops' end point of a 'vanilla' network built from [101].

        The network is built for a batch of three random 224x224 RGB images
        with 1001 classes, and the evaluated 'flops' end point is compared
        against a fixed reference value repeated once per batch element.
        """
        batch_size = 3
        input_shape = (batch_size, 224, 224, 3)
        num_classes = 1001
        # TF graph_metrics value: 15614055401 (0.1% difference)
        expected_flops = 15602814976

        with self.test_session() as sess:
            images = tf.random_uniform(input_shape)
            scope = imagenet_model.resnet_arg_scope(is_training=False)
            with slim.arg_scope(scope):
                _, end_points = imagenet_model.get_network(
                    images, [101], num_classes, 'vanilla')
                flops_out = sess.run(end_points['flops'])
                self.assertAllEqual(flops_out,
                                    [expected_flops] * batch_size)
    def _runBatch(self, is_training, model_type, model=None):
        """Builds a small ImageNet network and runs a single session step.

        Args:
          is_training: if True, a softmax cross-entropy loss and a momentum
            train op are built and one training step is run; otherwise the
            logits are evaluated once and their shape is checked.
          model_type: model type string forwarded to
            `imagenet_model.get_network`. For 'act', 'act_early_stopping' and
            'sact' the ACT and FLOPS metric maps are built as well, and the
            ponder costs are added to the loss when training.
          model: model specification forwarded to `get_network`. Defaults to
            `[2, 2, 2, 2]`.
        """
        # Avoid a mutable default argument shared between calls.
        if model is None:
            model = [2, 2, 2, 2]

        batch_size = 2
        height, width = 128, 128
        num_classes = 10
        adaptive_types = ('act', 'act_early_stopping', 'sact')

        with self.test_session() as sess:
            images = tf.random_uniform((batch_size, height, width, 3))
            with slim.arg_scope(
                    imagenet_model.resnet_arg_scope(is_training=is_training)):
                # Build the network for the requested model type; previously
                # 'sact' was hard-coded here and `model_type` was ignored.
                logits, end_points = imagenet_model.get_network(
                    images,
                    model,
                    num_classes,
                    model_type=model_type,
                    base_channels=1)
                if model_type in adaptive_types:
                    metrics = summary_utils.act_metric_map(
                        end_points, not is_training)
                    metrics.update(
                        summary_utils.flops_metric_map(end_points,
                                                       not is_training))
                else:
                    metrics = {}

            if is_training:
                labels = tf.random_uniform((batch_size, ),
                                           maxval=num_classes,
                                           dtype=tf.int32)
                one_hot_labels = slim.one_hot_encoding(labels, num_classes)
                tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                                logits=logits,
                                                label_smoothing=0.1,
                                                weights=1.0)
                if model_type in adaptive_types:
                    training_utils.add_all_ponder_costs(end_points,
                                                        weights=1.0)
                total_loss = tf.losses.get_total_loss()
                optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
                train_op = slim.learning.create_train_op(total_loss, optimizer)
                sess.run(tf.global_variables_initializer())
                sess.run((train_op, metrics))
            else:
                sess.run([
                    tf.local_variables_initializer(),
                    tf.global_variables_initializer()
                ])
                # The metrics are fetched only to check that they evaluate.
                logits_out, _ = sess.run((logits, metrics))
                self.assertEqual(logits_out.shape, (batch_size, num_classes))
def main(_):
    """Restores the latest checkpoint and re-saves it under FLAGS.output_dir.

    The model graph for FLAGS.dataset ('imagenet' or 'cifar') is built on a
    single random image, the newest checkpoint in FLAGS.input_dir is
    restored, and the variables are saved again with `write_version=2` to
    `FLAGS.output_dir + '/model'`.
    """
    output_dir = FLAGS.output_dir
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    assert FLAGS.model is not None
    assert FLAGS.model_type in ('vanilla', 'act', 'act_early_stopping', 'sact')
    assert FLAGS.dataset in ('imagenet', 'cifar')

    # Build the model graph; only its variables are needed by the saver.
    if FLAGS.dataset == 'imagenet':
        images = tf.random_uniform((1, 224, 224, 3))
        units = utils.split_and_int(FLAGS.model)
        arg_scope = imagenet_model.resnet_arg_scope(is_training=False)
        with slim.arg_scope(arg_scope):
            imagenet_model.get_network(
                images, units, 1001, model_type=FLAGS.model_type)
    elif FLAGS.dataset == 'cifar':
        images = tf.random_uniform((1, 32, 32, 3))
        units = utils.split_and_int(FLAGS.model)
        arg_scope = cifar_model.resnet_arg_scope(is_training=False)
        with slim.arg_scope(arg_scope):
            cifar_model.resnet(images,
                               model=units,
                               num_classes=10,
                               model_type=FLAGS.model_type)

    global_step = slim.get_or_create_global_step()

    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.input_dir)
    assert latest_checkpoint is not None

    saver = tf.train.Saver(write_version=2)
    target_prefix = output_dir + '/model'

    with tf.Session() as sess:
        saver.restore(sess, latest_checkpoint)
        saver.save(sess, target_prefix, global_step=global_step)
    def testVisualizationBasic(self):
        """Checks the output shapes of the SACT image heatmaps.

        Heatmaps for 'ponder_cost' and 'num_units' are built for a small
        'sact' network and each is expected to have the shape
        (num_images, height, 2 * width + border, 3).
        """
        batch_size, num_images = 5, 3
        height = width = 128
        num_classes = 10
        border = 5
        expected_shape = (num_images, height, width * 2 + border, 3)

        with self.test_session() as sess:
            images = tf.random_uniform((batch_size, height, width, 3))
            arg_scope = imagenet_model.resnet_arg_scope(is_training=False)
            with slim.arg_scope(arg_scope):
                _, end_points = imagenet_model.get_network(
                    images, [2, 2, 2, 2],
                    num_classes,
                    model_type='sact',
                    base_channels=1)

                # Same construction order as before: ponder cost, then units.
                heatmaps = [
                    summary_utils.sact_image_heatmap(end_points,
                                                     metric_name,
                                                     num_images=num_images,
                                                     alpha=0.75,
                                                     border=border)
                    for metric_name in ('ponder_cost', 'num_units')
                ]

                sess.run(tf.global_variables_initializer())
                for heatmap_out in sess.run(heatmaps):
                    self.assertEqual(heatmap_out.shape, expected_shape)
def main(_):
    """Builds an adaptive ImageNet model and calls `export_to_h5` on it.

    Only the 'act', 'act_early_stopping' and 'sact' model types are accepted.
    The images, the network end points and the relevant flags are handed to
    `summary_utils.export_to_h5`.
    """
    assert FLAGS.model_type in ('act', 'act_early_stopping', 'sact')

    with tf.Graph().as_default():
        images, _, _, num_classes = imagenet_data_provider.provide_data(
            FLAGS.split_name,
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            is_training=False)

        # Define the model:
        arg_scope = imagenet_model.resnet_arg_scope(is_training=False)
        with slim.arg_scope(arg_scope):
            units = utils.split_and_int(FLAGS.model)
            _, end_points = imagenet_model.get_network(
                images, units, num_classes, model_type=FLAGS.model_type)

            is_sact = FLAGS.model_type == 'sact'
            summary_utils.export_to_h5(FLAGS.checkpoint_dir,
                                       FLAGS.export_path,
                                       images,
                                       end_points,
                                       FLAGS.num_examples,
                                       FLAGS.batch_size,
                                       is_sact)
def main(_):
    """Evaluates an ImageNet model on checkpoints from FLAGS.checkpoint_dir.

    Builds the network in inference mode, defines streaming accuracy and
    recall@5 plus the FLOPS (and, for adaptive model types, ACT) metrics,
    and runs either a single evaluation of the latest checkpoint
    (FLAGS.evaluate_once) or a continuous evaluation loop.
    """
    g = tf.Graph()
    with g.as_default():
        data_tuple = imagenet_data_provider.provide_data(
            FLAGS.split_name,
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            is_training=False,
            image_size=FLAGS.image_size)
        images, one_hot_labels, _, num_classes = data_tuple

        # Define the model:
        with slim.arg_scope(
                imagenet_model.resnet_arg_scope(is_training=False)):
            model = utils.split_and_int(FLAGS.model)
            _, end_points = imagenet_model.get_network(
                images, model, num_classes, model_type=FLAGS.model_type)

            predictions = tf.argmax(end_points['predictions'], 1)

            # Define the metrics:
            labels = tf.argmax(one_hot_labels, 1)
            metric_map = {
                'eval/Accuracy':
                tf.contrib.metrics.streaming_accuracy(predictions, labels),
                'eval/Recall@5':
                tf.contrib.metrics.streaming_sparse_recall_at_k(
                    end_points['predictions'], tf.expand_dims(labels, 1), 5),
            }
            metric_map.update(summary_utils.flops_metric_map(end_points, True))
            if FLAGS.model_type in ['act', 'act_early_stopping', 'sact']:
                metric_map.update(
                    summary_utils.act_metric_map(end_points, True))

            names_to_values, names_to_updates = (
                tf.contrib.metrics.aggregate_metric_map(metric_map))

            # `items()` works on both Python 2 and 3 (`iteritems()` is
            # Python 2 only).
            for name, value in names_to_values.items():
                summ = tf.summary.scalar(name, value, collections=[])
                # Also print the metric value whenever the summary is run.
                summ = tf.Print(summ, [value], name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

            if FLAGS.model_type == 'sact':
                summary_utils.add_heatmaps_image_summary(end_points, border=10)

            # This ensures that we make a single pass over all of the data.
            # `math.ceil` returns a float on Python 2, so convert explicitly.
            num_batches = int(
                math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))

            if not FLAGS.evaluate_once:
                eval_function = slim.evaluation.evaluation_loop
                checkpoint_path = FLAGS.checkpoint_dir
                kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
            else:
                eval_function = slim.evaluation.evaluate_once
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_dir)
                assert checkpoint_path is not None
                kwargs = {}

            # Pass a real list: on Python 3 `dict.values()` is a view object.
            eval_function(FLAGS.master,
                          checkpoint_path,
                          logdir=FLAGS.eval_dir,
                          num_evals=num_batches,
                          eval_op=list(names_to_updates.values()),
                          **kwargs)
# Example #7
# 0
def main(_):
    """Saves SACT ponder cost maps for images matching FLAGS.images_pattern.

    For every matching file the image is decoded as JPEG, optionally resized
    so that its longer side equals FLAGS.image_size, and passed through a
    'sact' network restored from the latest checkpoint in
    FLAGS.checkpoint_dir. The ponder cost map, a matching colorbar and (when
    resizing is enabled) the resized image are written to FLAGS.output_dir.
    """
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    num_classes = 1001

    path = tf.placeholder(tf.string)
    contents = tf.read_file(path)
    image = tf.image.decode_jpeg(contents, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    images = tf.expand_dims(image, 0)
    images.set_shape([1, None, None, 3])

    if FLAGS.image_size:
        sh = tf.shape(image)
        height, width = tf.to_float(sh[0]), tf.to_float(sh[1])
        longer_size = tf.constant(FLAGS.image_size, dtype=tf.float32)

        # Scale the longer side to `longer_size`, keeping the aspect ratio.
        new_size = tf.cond(
            height >= width,
            lambda: (longer_size, (width / height) * longer_size),
            lambda: ((height / width) * longer_size, longer_size))
        images_resized = tf.image.resize_images(
            images,
            size=tf.to_int32(tf.stack(new_size)),
            method=tf.image.ResizeMethod.BICUBIC)
    else:
        images_resized = images

    images_resized = preprocessing(images_resized)

    # Define the model:
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
        model = utils.split_and_int(FLAGS.model)
        _, end_points = imagenet_model.get_network(images_resized,
                                                   model,
                                                   num_classes,
                                                   model_type='sact')
        ponder_cost_map = summary_utils.sact_map(end_points, 'ponder_cost')

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    assert checkpoint_path is not None

    saver = tf.train.Saver()

    # Build the fetch ops once. Creating them inside the loop would add new
    # nodes to the graph for every processed image.
    fetches = [
        tf.squeeze(reverse_preprocessing(images_resized), 0),
        tf.squeeze(ponder_cost_map, [0, 3]),
    ]

    # The context manager closes the session when processing is finished.
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)

        for current_path in glob.glob(FLAGS.images_pattern):
            print('Processing {}'.format(current_path))

            image_resized_out, ponder_cost_map_out = sess.run(
                fetches, feed_dict={path: current_path})

            basename = os.path.splitext(os.path.basename(current_path))[0]
            if FLAGS.image_size:
                matplotlib.image.imsave(
                    os.path.join(FLAGS.output_dir,
                                 '{}_im.jpg'.format(basename)),
                    image_resized_out)
            matplotlib.image.imsave(
                os.path.join(FLAGS.output_dir,
                             '{}_ponder.jpg'.format(basename)),
                ponder_cost_map_out,
                cmap='viridis')

            min_ponder = ponder_cost_map_out.min()
            max_ponder = ponder_cost_map_out.max()
            print('Minimum/maximum ponder cost {:.2f}/{:.2f}'.format(
                min_ponder, max_ponder))

            # Standalone colorbar covering this image's ponder cost range.
            fig = plt.figure(figsize=(0.2, 2))
            ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
            matplotlib.colorbar.ColorbarBase(
                ax,
                cmap='viridis',
                norm=matplotlib.colors.Normalize(vmin=min_ponder,
                                                 vmax=max_ponder))
            ax.tick_params(labelsize=12)
            filename = os.path.join(FLAGS.output_dir,
                                    '{}_colorbar.pdf'.format(basename))
            fig.savefig(filename, bbox_inches='tight')
            # Release the figure; otherwise one figure per image stays open.
            plt.close(fig)
def main(_):
    """Trains an ImageNet model with TF-Slim.

    Builds the input pipeline and the network in training mode, adds a
    label-smoothed softmax cross-entropy loss (plus ponder costs weighted by
    FLAGS.tau for the adaptive model types), and optimizes it with momentum
    SGD under a staircase exponential learning rate decay. Variables may be
    initialized through `training_utils.finetuning_init_fn` from
    FLAGS.finetune_path.
    """
    g = tf.Graph()
    with g.as_default():
        # If ps_tasks is zero, the local device is used. When using multiple
        # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
        # across the different devices.
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            data_tuple = imagenet_data_provider.provide_data(
                FLAGS.split_name,
                FLAGS.batch_size,
                dataset_dir=FLAGS.dataset_dir,
                is_training=True,
                image_size=FLAGS.image_size)
            images, labels, examples_per_epoch, num_classes = data_tuple

            # Define the model:
            with slim.arg_scope(
                    imagenet_model.resnet_arg_scope(is_training=True)):
                model = utils.split_and_int(FLAGS.model)
                logits, end_points = imagenet_model.get_network(
                    images, model, num_classes, model_type=FLAGS.model_type)

                # Specify the loss function. Keyword arguments are used
                # because the positional order of
                # `tf.losses.softmax_cross_entropy` is
                # (onehot_labels, logits); passing (logits, labels)
                # positionally swaps the two.
                # NOTE(review): assumes `labels` from provide_data is already
                # one-hot encoded -- confirm against the data provider.
                tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                logits=logits,
                                                label_smoothing=0.1,
                                                weights=1.0)
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    training_utils.add_all_ponder_costs(end_points,
                                                        weights=FLAGS.tau)
                total_loss = tf.losses.get_total_loss()

                # Configure the learning rate using an exponential decay.
                decay_steps = int(examples_per_epoch / FLAGS.batch_size *
                                  FLAGS.num_epochs_per_decay)

                learning_rate = tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    slim.get_or_create_global_step(),
                    decay_steps,
                    FLAGS.learning_rate_decay_factor,
                    staircase=True)

                opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

                init_fn = training_utils.finetuning_init_fn(
                    FLAGS.finetune_path)

                train_tensor = slim.learning.create_train_op(
                    total_loss,
                    optimizer=opt,
                    update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

                # Summaries:
                tf.summary.scalar('losses/Total Loss', total_loss)
                tf.summary.scalar('training/Learning Rate', learning_rate)

                metric_map = {
                }  # summary_utils.flops_metric_map(end_points, False)
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    metric_map.update(
                        summary_utils.act_metric_map(end_points, False))
                # `items()` works on both Python 2 and 3 (`iteritems()` is
                # Python 2 only).
                for name, value in metric_map.items():
                    tf.summary.scalar(name, value)

                if FLAGS.model_type == 'sact':
                    summary_utils.add_heatmaps_image_summary(end_points,
                                                             border=10)

                # Delay grows linearly with the task index.
                startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

                slim.learning.train(
                    train_tensor,
                    init_fn=init_fn,
                    logdir=FLAGS.train_log_dir,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    startup_delay_steps=startup_delay_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs)