def main(_):
    """Re-save the latest checkpoint from FLAGS.input_dir into FLAGS.output_dir.

    The model graph for FLAGS.model / FLAGS.model_type is built on a random
    input batch. The inputs are never evaluated; the graph exists only so the
    Saver can see the model variables. The latest checkpoint in
    FLAGS.input_dir is restored and written back with ``write_version=2``
    under the prefix ``<output_dir>/model`` (the Saver appends the global
    step to that prefix).

    Raises:
      ValueError: if FLAGS.model is unset, FLAGS.model_type or FLAGS.dataset
        is unsupported, or FLAGS.input_dir holds no checkpoint.
    """
    import os  # Local import: only needed to build the output prefix.

    supported_model_types = ('vanilla', 'act', 'act_early_stopping', 'sact')
    # (height, width, num_classes) for every supported dataset.
    dataset_specs = {
        'imagenet': (224, 224, 1001),
        'cifar': (32, 32, 10),
    }

    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, after which a bad flag would fail later with an obscure
    # NameError instead of a clear message.
    if FLAGS.model is None:
        raise ValueError('--model must be set.')
    if FLAGS.model_type not in supported_model_types:
        raise ValueError('Unsupported --model_type: %r' % (FLAGS.model_type,))
    if FLAGS.dataset not in dataset_specs:
        raise ValueError('Unsupported --dataset: %r' % (FLAGS.dataset,))

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    batch_size = 1
    height, width, num_classes = dataset_specs[FLAGS.dataset]

    images = tf.random_uniform((batch_size, height, width, 3))
    model = utils.split_and_int(FLAGS.model)

    # Define the model. Only its variables are needed; its outputs are unused.
    if FLAGS.dataset == 'imagenet':
        with slim.arg_scope(
                imagenet_model.resnet_arg_scope(is_training=False)):
            imagenet_model.get_network(
                images, model, num_classes, model_type=FLAGS.model_type)
    else:  # 'cifar'
        with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
            cifar_model.resnet(
                images,
                model=model,
                num_classes=num_classes,
                model_type=FLAGS.model_type)

    # Get (or create) the global step so it can be passed to saver.save.
    tf_global_step = slim.get_or_create_global_step()

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %s' % (FLAGS.input_dir,))

    saver = tf.train.Saver(write_version=2)

    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        saver.save(sess,
                   os.path.join(FLAGS.output_dir, 'model'),
                   global_step=tf_global_step)
def main(_):
    """Build the ImageNet inference graph and hand it to export_to_h5.

    Only adaptive-computation models ('act', 'act_early_stopping', 'sact')
    are accepted. The images and end points of the network, together with
    FLAGS.checkpoint_dir, FLAGS.export_path, FLAGS.num_examples and
    FLAGS.batch_size, are passed to summary_utils.export_to_h5. Judging by
    its name, that helper writes an HDF5 file; its exact behaviour is defined
    elsewhere.
    """
    adaptive_model_types = ('act', 'act_early_stopping', 'sact')
    assert FLAGS.model_type in adaptive_model_types

    graph = tf.Graph()
    with graph.as_default():
        # Only the images and the class count are needed here.
        images, _, _, num_classes = imagenet_data_provider.provide_data(
            FLAGS.split_name,
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            is_training=False)

        inference_scope = imagenet_model.resnet_arg_scope(is_training=False)
        with slim.arg_scope(inference_scope):
            block_counts = utils.split_and_int(FLAGS.model)
            _, end_points = imagenet_model.get_network(
                images,
                block_counts,
                num_classes,
                model_type=FLAGS.model_type)

            is_spatially_adaptive = FLAGS.model_type == 'sact'
            summary_utils.export_to_h5(
                FLAGS.checkpoint_dir,
                FLAGS.export_path,
                images,
                end_points,
                FLAGS.num_examples,
                FLAGS.batch_size,
                is_spatially_adaptive)
def main(_):
    """Evaluate an ImageNet ResNet checkpoint with streaming metrics.

    Builds the inference graph on FLAGS.split_name and registers accuracy,
    recall@5, FLOPs and (for 'act'/'act_early_stopping'/'sact' models)
    ponder-cost metrics. It then runs either a single evaluation of the latest
    checkpoint (FLAGS.evaluate_once) or a polling evaluation loop over
    FLAGS.checkpoint_dir every FLAGS.eval_interval_secs seconds.
    """
    g = tf.Graph()
    with g.as_default():
        data_tuple = imagenet_data_provider.provide_data(
            FLAGS.split_name,
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            is_training=False,
            image_size=FLAGS.image_size)
        images, one_hot_labels, _, num_classes = data_tuple

        # Define the model:
        with slim.arg_scope(
                imagenet_model.resnet_arg_scope(is_training=False)):
            model = utils.split_and_int(FLAGS.model)
            _, end_points = imagenet_model.get_network(
                images, model, num_classes, model_type=FLAGS.model_type)

            predictions = tf.argmax(end_points['predictions'], 1)

            # Define the metrics:
            labels = tf.argmax(one_hot_labels, 1)
            metric_map = {
                'eval/Accuracy':
                tf.contrib.metrics.streaming_accuracy(predictions, labels),
                'eval/Recall@5':
                tf.contrib.metrics.streaming_sparse_recall_at_k(
                    end_points['predictions'], tf.expand_dims(labels, 1), 5),
            }
            metric_map.update(summary_utils.flops_metric_map(end_points, True))
            if FLAGS.model_type in ['act', 'act_early_stopping', 'sact']:
                metric_map.update(
                    summary_utils.act_metric_map(end_points, True))

            names_to_values, names_to_updates = (
                tf.contrib.metrics.aggregate_metric_map(metric_map))

            # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
            for name, value in names_to_values.items():
                summ = tf.summary.scalar(name, value, collections=[])
                # Also print the metric whenever the summary op is evaluated.
                summ = tf.Print(summ, [value], name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

            if FLAGS.model_type == 'sact':
                summary_utils.add_heatmaps_image_summary(end_points, border=10)

            # This ensures that we make a single pass over all of the data.
            # math.ceil returns a float on Python 2, so convert explicitly.
            num_batches = int(math.ceil(FLAGS.num_examples /
                                        float(FLAGS.batch_size)))

            if not FLAGS.evaluate_once:
                eval_function = slim.evaluation.evaluation_loop
                checkpoint_path = FLAGS.checkpoint_dir
                kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
            else:
                eval_function = slim.evaluation.evaluate_once
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_dir)
                assert checkpoint_path is not None
                kwargs = {}

            # list(...) because on Python 3 dict.values() is a view, which
            # Session.run does not accept as a fetch.
            eval_function(FLAGS.master,
                          checkpoint_path,
                          logdir=FLAGS.eval_dir,
                          num_evals=num_batches,
                          eval_op=list(names_to_updates.values()),
                          **kwargs)
Esempio n. 4
0
def main(_):
    """Save SACT ponder-cost visualisations for images in FLAGS.images_pattern.

    For every file matched by FLAGS.images_pattern (decoded as a 3-channel
    JPEG), the image is optionally resized so its longer side equals
    FLAGS.image_size. The 'sact' ImageNet model is then run from the latest
    checkpoint in FLAGS.checkpoint_dir, and these files are written to
    FLAGS.output_dir:
      * <name>_im.jpg       -- the resized image (only if FLAGS.image_size),
      * <name>_ponder.jpg   -- the ponder-cost map, viridis colormap,
      * <name>_colorbar.pdf -- a colorbar spanning the map's min..max.
    """
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    num_classes = 1001

    path = tf.placeholder(tf.string)
    contents = tf.read_file(path)
    image = tf.image.decode_jpeg(contents, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    images = tf.expand_dims(image, 0)
    images.set_shape([1, None, None, 3])

    if FLAGS.image_size:
        # Resize so the longer side is FLAGS.image_size, keeping aspect ratio.
        sh = tf.shape(image)
        height, width = tf.to_float(sh[0]), tf.to_float(sh[1])
        longer_size = tf.constant(FLAGS.image_size, dtype=tf.float32)

        new_size = tf.cond(
            height >= width, lambda: (longer_size,
                                      (width / height) * longer_size), lambda:
            ((height / width) * longer_size, longer_size))
        images_resized = tf.image.resize_images(
            images,
            size=tf.to_int32(tf.stack(new_size)),
            method=tf.image.ResizeMethod.BICUBIC)
    else:
        images_resized = images

    images_resized = preprocessing(images_resized)

    # Define the model:
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
        model = utils.split_and_int(FLAGS.model)
        _, end_points = imagenet_model.get_network(images_resized,
                                                   model,
                                                   num_classes,
                                                   model_type='sact')
        ponder_cost_map = summary_utils.sact_map(end_points, 'ponder_cost')

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    assert checkpoint_path is not None

    saver = tf.train.Saver()

    # Context manager so the session is closed when processing finishes.
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)

        # Build the fetched tensors once. Creating them inside the loop would
        # add new ops to the graph for every image processed.
        image_out_tensor = tf.squeeze(reverse_preprocessing(images_resized), 0)
        ponder_out_tensor = tf.squeeze(ponder_cost_map, [0, 3])

        for current_path in glob.glob(FLAGS.images_pattern):
            print('Processing {}'.format(current_path))

            image_resized_out, ponder_cost_map_out = sess.run(
                [image_out_tensor, ponder_out_tensor],
                feed_dict={path: current_path})

            basename = os.path.splitext(os.path.basename(current_path))[0]
            if FLAGS.image_size:
                matplotlib.image.imsave(
                    os.path.join(FLAGS.output_dir,
                                 '{}_im.jpg'.format(basename)),
                    image_resized_out)
            matplotlib.image.imsave(
                os.path.join(FLAGS.output_dir,
                             '{}_ponder.jpg'.format(basename)),
                ponder_cost_map_out,
                cmap='viridis')

            min_ponder = ponder_cost_map_out.min()
            max_ponder = ponder_cost_map_out.max()
            print('Minimum/maximum ponder cost {:.2f}/{:.2f}'.format(
                min_ponder, max_ponder))

            fig = plt.figure(figsize=(0.2, 2))
            try:
                ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
                matplotlib.colorbar.ColorbarBase(
                    ax,
                    cmap='viridis',
                    norm=matplotlib.colors.Normalize(vmin=min_ponder,
                                                     vmax=max_ponder))
                ax.tick_params(labelsize=12)
                filename = os.path.join(FLAGS.output_dir,
                                        '{}_colorbar.pdf'.format(basename))
                fig.savefig(filename, bbox_inches='tight')
            finally:
                # pyplot keeps every open figure alive; close it so memory
                # does not grow with the number of processed images.
                plt.close(fig)
def evaluate():
  """Evaluate a CIFAR ResNet checkpoint with streaming metrics.

  Builds the inference graph on FLAGS.split_name and registers these metrics:
  accuracy, the mean of tf.losses.get_total_loss() (cross-entropy plus, for
  adaptive models, ponder costs weighted by FLAGS.tau), FLOPs, and ACT
  metrics. It then runs either a single evaluation of the latest checkpoint
  (FLAGS.evaluate_once) or a polling evaluation loop over
  FLAGS.checkpoint_dir.
  """
  g = tf.Graph()
  with g.as_default():
    data_tuple = cifar_data_provider.provide_data(FLAGS.split_name,
                                                  FLAGS.eval_batch_size,
                                                  dataset_dir=FLAGS.dataset_dir)
    images, _, one_hot_labels, num_samples, num_classes = data_tuple

    # Define the model:
    with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
      model = utils.split_and_int(FLAGS.model)
      logits, end_points = cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=FLAGS.model_type)

      predictions = tf.argmax(logits, 1)

      # Rebuild the loss terms so the total loss can be reported as a metric.
      tf.losses.softmax_cross_entropy(
          onehot_labels=one_hot_labels, logits=logits)
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)

      loss = tf.losses.get_total_loss()

      # Define the metrics:
      labels = tf.argmax(one_hot_labels, 1)
      metric_map = {
          'eval/Accuracy':
                tf.contrib.metrics.streaming_accuracy(predictions, labels),
          'eval/Mean Loss':
                tf.contrib.metrics.streaming_mean(loss),
      }
      metric_map.update(summary_utils.flops_metric_map(end_points, True))
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        metric_map.update(summary_utils.act_metric_map(end_points, True))
      names_to_values, names_to_updates = (
          tf.contrib.metrics.aggregate_metric_map(metric_map))

      # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
      for name, value in names_to_values.items():
        summ = tf.summary.scalar(name, value, collections=[])
        # Also print the metric whenever the summary op is evaluated.
        summ = tf.Print(summ, [value], name)
        tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

      if FLAGS.model_type == 'sact':
        summary_utils.add_heatmaps_image_summary(end_points)

      # This ensures that we make a single pass over all of the data.
      # math.ceil returns a float on Python 2, so convert explicitly.
      num_batches = int(math.ceil(num_samples / float(FLAGS.eval_batch_size)))

      if not FLAGS.evaluate_once:
        eval_function = slim.evaluation.evaluation_loop
        checkpoint_path = FLAGS.checkpoint_dir
        eval_kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
      else:
        eval_function = slim.evaluation.evaluate_once
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        assert checkpoint_path is not None
        eval_kwargs = {}

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True

      # list(...) because on Python 3 dict.values() is a view, which
      # Session.run does not accept as a fetch.
      eval_function(
          FLAGS.master,
          checkpoint_path,
          logdir=FLAGS.eval_dir,
          num_evals=num_batches,
          eval_op=list(names_to_updates.values()),
          session_config=config,
          **eval_kwargs)
def train():
  """Train a CIFAR ResNet (vanilla or adaptive-computation) with slim.

  Loss is softmax cross-entropy. For 'act'/'act_early_stopping'/'sact'
  models, ponder costs weighted by FLAGS.tau are added on top. Optimisation
  uses momentum 0.9 with a piecewise-constant learning rate. Checkpoints and
  summaries go to FLAGS.train_log_dir; when that flag is empty, slim is
  given logdir=None.
  """
  # Only create the log directory when one was requested. The empty-flag case
  # maps to logdir=None below, so MakeDirs('') must not be attempted.
  logdir = FLAGS.train_log_dir or None
  if logdir and not tf.gfile.Exists(logdir):
    tf.gfile.MakeDirs(logdir)

  g = tf.Graph()
  with g.as_default():
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      data_tuple = cifar_data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir)
      images, _, one_hot_labels, _, num_classes = data_tuple

      # Define the model:
      with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=True)):
        model = utils.split_and_int(FLAGS.model)
        logits, end_points = cifar_model.resnet(
            images,
            model=model,
            num_classes=num_classes,
            model_type=FLAGS.model_type)

        # Specify the loss function:
        tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=logits)
        if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
          training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
        total_loss = tf.losses.get_total_loss()
        tf.summary.scalar('Total Loss', total_loss)

        # FLOPs metrics (summary_utils.flops_metric_map(end_points, False))
        # are disabled during training; only ACT metrics are logged.
        metric_map = {}
        if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
          metric_map.update(summary_utils.act_metric_map(end_points, False))
        # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
        for name, value in metric_map.items():
          tf.summary.scalar(name, value)

        if FLAGS.model_type == 'sact':
          summary_utils.add_heatmaps_image_summary(end_points)

        init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path)

        # Specify the optimization scheme:
        global_step = slim.get_or_create_global_step()
        # Original LR schedule
        # boundaries = [40000, 60000, 80000]
        # "Longer" LR schedule
        boundaries = [60000, 75000, 90000]
        # Boundaries are built as int64 constants to match the global step.
        boundaries = [tf.constant(x, dtype=tf.int64) for x in boundaries]
        values = [0.1, 0.01, 0.001, 0.0001]
        learning_rate = tf.train.piecewise_constant(global_step, boundaries,
                                                    values)
        tf.summary.scalar('Learning Rate', learning_rate)
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)

        # Set up training.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Run training.
        slim.learning.train(
            train_op=train_op,
            init_fn=init_fn,
            logdir=logdir,
            master=FLAGS.master,
            number_of_steps=FLAGS.max_number_of_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config)
def main(_):
    """Train an ImageNet ResNet (vanilla or adaptive-computation) with slim.

    Loss is label-smoothed (0.1) softmax cross-entropy. For
    'act'/'act_early_stopping'/'sact' models, ponder costs weighted by
    FLAGS.tau are added on top. Optimisation uses momentum with a staircase
    exponential learning-rate decay applied every FLAGS.num_epochs_per_decay
    epochs. Replica 0 (FLAGS.task == 0) acts as chief. Other replicas start
    after FLAGS.task * FLAGS.startup_delay_steps steps.
    """
    g = tf.Graph()
    with g.as_default():
        # If ps_tasks is zero, the local device is used. When using multiple
        # (non-local) replicas, the ReplicaDeviceSetter distributes the
        # variables across the different devices.
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            data_tuple = imagenet_data_provider.provide_data(
                FLAGS.split_name,
                FLAGS.batch_size,
                dataset_dir=FLAGS.dataset_dir,
                is_training=True,
                image_size=FLAGS.image_size)
            # `labels` is consumed as one-hot labels by the loss below.
            images, labels, examples_per_epoch, num_classes = data_tuple

            # Define the model:
            with slim.arg_scope(
                    imagenet_model.resnet_arg_scope(is_training=True)):
                model = utils.split_and_int(FLAGS.model)
                logits, end_points = imagenet_model.get_network(
                    images, model, num_classes, model_type=FLAGS.model_type)

                # Specify the loss function. The signature is
                # (onehot_labels, logits, ...); passing them positionally as
                # (logits, labels) swapped the two, which applied label
                # smoothing to the logits and took the softmax over the labels.
                tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                logits=logits,
                                                label_smoothing=0.1,
                                                weights=1.0)
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    training_utils.add_all_ponder_costs(end_points,
                                                        weights=FLAGS.tau)
                total_loss = tf.losses.get_total_loss()

                # Configure the learning rate using an exponential decay.
                # float() avoids Python 2 integer division truncating
                # steps-per-epoch before multiplying by the epoch count.
                decay_steps = int(examples_per_epoch /
                                  float(FLAGS.batch_size) *
                                  FLAGS.num_epochs_per_decay)

                learning_rate = tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    slim.get_or_create_global_step(),
                    decay_steps,
                    FLAGS.learning_rate_decay_factor,
                    staircase=True)

                opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

                init_fn = training_utils.finetuning_init_fn(
                    FLAGS.finetune_path)

                train_tensor = slim.learning.create_train_op(
                    total_loss,
                    optimizer=opt,
                    update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

                # Summaries:
                tf.summary.scalar('losses/Total Loss', total_loss)
                tf.summary.scalar('training/Learning Rate', learning_rate)

                # FLOPs metrics (summary_utils.flops_metric_map(end_points,
                # False)) are disabled during training; only ACT metrics are
                # logged.
                metric_map = {}
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    metric_map.update(
                        summary_utils.act_metric_map(end_points, False))
                # .items() works on Python 2 and 3; .iteritems() is Py2 only.
                for name, value in metric_map.items():
                    tf.summary.scalar(name, value)

                if FLAGS.model_type == 'sact':
                    summary_utils.add_heatmaps_image_summary(end_points,
                                                             border=10)

                startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

                slim.learning.train(
                    train_tensor,
                    init_fn=init_fn,
                    logdir=FLAGS.train_log_dir,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    startup_delay_steps=startup_delay_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs)