# Shared setup for the examples below (TensorFlow 1.x). The project modules
# (imagenet_data_provider, imagenet_model, cifar_data_provider, cifar_model,
# summary_utils, training_utils, utils) and the FLAGS definitions come from
# the surrounding codebase and are not reproduced in this listing.
import math

import tensorflow as tf

slim = tf.contrib.slim


def main(_):
    g = tf.Graph()
    with g.as_default():
        data_tuple = imagenet_data_provider.provide_data(
            FLAGS.split_name,
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            is_training=False,
            image_size=FLAGS.image_size)
        images, one_hot_labels, examples_per_epoch, num_classes = data_tuple

        # Define the model:
        with slim.arg_scope(
                imagenet_model.resnet_arg_scope(is_training=False)):
            model = utils.split_and_int(FLAGS.model)
            logits, end_points = imagenet_model.get_network(
                images, model, num_classes, model_type=FLAGS.model_type)

            predictions = tf.argmax(end_points['predictions'], 1)

            # Define the metrics:
            labels = tf.argmax(one_hot_labels, 1)
            metric_map = {
                'eval/Accuracy':
                tf.contrib.metrics.streaming_accuracy(predictions, labels),
                'eval/Recall@5':
                tf.contrib.metrics.streaming_sparse_recall_at_k(
                    end_points['predictions'], tf.expand_dims(labels, 1), 5),
            }
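            # Each tf.contrib.metrics streaming_* call returns a
            # (value_op, update_op) pair; aggregate_metric_map below splits
            # the map into one dict of value ops and one of update ops.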
            metric_map.update(summary_utils.flops_metric_map(end_points, True))
            if FLAGS.model_type in ['act', 'act_early_stopping', 'sact']:
                metric_map.update(
                    summary_utils.act_metric_map(end_points, True))

            names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map(
                metric_map)

            # Create each scalar summary outside the default collections,
            # wrap it in tf.Print so the value is also logged during
            # evaluation, then register the printing op as the summary.
            for name, value in names_to_values.items():
                summ = tf.summary.scalar(name, value, collections=[])
                summ = tf.Print(summ, [value], name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

            if FLAGS.model_type == 'sact':
                summary_utils.add_heatmaps_image_summary(end_points, border=10)

            # This ensures that we make a single pass over all of the data.
            num_batches = int(math.ceil(FLAGS.num_examples /
                                        float(FLAGS.batch_size)))

            if not FLAGS.evaluate_once:
                eval_function = slim.evaluation.evaluation_loop
                checkpoint_path = FLAGS.checkpoint_dir
                kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
            else:
                eval_function = slim.evaluation.evaluate_once
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_dir)
                assert checkpoint_path is not None
                kwargs = {}
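            # evaluation_loop polls checkpoint_dir for new checkpoints every
            # eval_interval_secs; evaluate_once needs one concrete checkpoint,
            # hence the tf.train.latest_checkpoint() lookup above.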

            eval_function(FLAGS.master,
                          checkpoint_path,
                          logdir=FLAGS.eval_dir,
                          num_evals=num_batches,
                          eval_op=list(names_to_updates.values()),
                          **kwargs)
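The listing defines main(_) but omits the launcher; under TensorFlow 1.x such
scripts are started as below (a minimal sketch, not part of the original
example):

if __name__ == '__main__':
    tf.app.run()  # Parses the command-line flags into FLAGS, then calls main().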
Example #2
def train():
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  g = tf.Graph()
  with g.as_default():
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      data_tuple = cifar_data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir)
      images, _, one_hot_labels, _, num_classes = data_tuple

      # Define the model:
      with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=True)):
        model = utils.split_and_int(FLAGS.model)
        logits, end_points = cifar_model.resnet(
            images,
            model=model,
            num_classes=num_classes,
            model_type=FLAGS.model_type)

        # Specify the loss function:
        tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=logits)
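        # For adaptive-computation-time models, also add the ponder cost to
        # the loss collection; tau trades accuracy against computation.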
        if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
          training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
        total_loss = tf.losses.get_total_loss()
        tf.summary.scalar('Total Loss', total_loss)

        metric_map = {}  # summary_utils.flops_metric_map(end_points, False)
        if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
          metric_map.update(summary_utils.act_metric_map(end_points, False))
        for name, value in metric_map.items():
          tf.summary.scalar(name, value)

        if FLAGS.model_type == 'sact':
          summary_utils.add_heatmaps_image_summary(end_points)

        init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path)

        # Specify the optimization scheme:
        global_step = slim.get_or_create_global_step()
        # Original LR schedule
        # boundaries = [40000, 60000, 80000]
        # "Longer" LR schedule
        boundaries = [60000, 75000, 90000]
        boundaries = [tf.constant(x, dtype=tf.int64) for x in boundaries]
        values = [0.1, 0.01, 0.001, 0.0001]
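        # piecewise_constant keeps the rate at values[i] between consecutive
        # boundaries: 0.1 up to step 60000, then 0.01, 0.001, and finally
        # 0.0001 beyond step 90000.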
        learning_rate = tf.train.piecewise_constant(global_step, boundaries,
                                                    values)
        tf.summary.scalar('Learning Rate', learning_rate)
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)

        # Set up training.
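        # create_train_op returns a tensor that, when run, applies one
        # optimizer step and evaluates to the current total loss.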
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        if FLAGS.train_log_dir:
          logdir = FLAGS.train_log_dir
        else:
          logdir = None

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
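        # allow_growth makes TensorFlow allocate GPU memory on demand instead
        # of reserving the whole device up front.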

        # Run training.
        slim.learning.train(
            train_op=train_op,
            init_fn=init_fn,
            logdir=logdir,
            master=FLAGS.master,
            number_of_steps=FLAGS.max_number_of_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config)
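Here too the launcher is omitted; the usual wiring would be a main that simply
calls train() (hypothetical sketch, the original file's main is not shown):

def main(_):
  train()


if __name__ == '__main__':
  tf.app.run()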
Example #3
def evaluate():
  g = tf.Graph()
  with g.as_default():
    data_tuple = cifar_data_provider.provide_data(FLAGS.split_name,
                                                  FLAGS.eval_batch_size,
                                                  dataset_dir=FLAGS.dataset_dir)
    images, _, one_hot_labels, num_samples, num_classes = data_tuple

    # Define the model:
    with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
      model = utils.split_and_int(FLAGS.model)
      logits, end_points = cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=FLAGS.model_type)

      predictions = tf.argmax(logits, 1)

      tf.losses.softmax_cross_entropy(
          onehot_labels=one_hot_labels, logits=logits)
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)

      loss = tf.losses.get_total_loss()

      # Define the metrics:
      labels = tf.argmax(one_hot_labels, 1)
      metric_map = {
          'eval/Accuracy':
                tf.contrib.metrics.streaming_accuracy(predictions, labels),
          'eval/Mean Loss':
                tf.contrib.metrics.streaming_mean(loss),
      }
      metric_map.update(summary_utils.flops_metric_map(end_points, True))
      if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
        metric_map.update(summary_utils.act_metric_map(end_points, True))
      names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map(
          metric_map)

      for name, value in names_to_values.items():
        summ = tf.summary.scalar(name, value, collections=[])
        summ = tf.Print(summ, [value], name)
        tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

      if FLAGS.model_type == 'sact':
        summary_utils.add_heatmaps_image_summary(end_points)

      # This ensures that we make a single pass over all of the data.
      num_batches = int(math.ceil(num_samples / float(FLAGS.eval_batch_size)))

      if not FLAGS.evaluate_once:
        eval_function = slim.evaluation.evaluation_loop
        checkpoint_path = FLAGS.checkpoint_dir
        eval_kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs}
      else:
        eval_function = slim.evaluation.evaluate_once
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        assert checkpoint_path is not None
        eval_kwargs = {}

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True

      eval_function(
          FLAGS.master,
          checkpoint_path,
          logdir=FLAGS.eval_dir,
          num_evals=num_batches,
          eval_op=list(names_to_updates.values()),
          session_config=config,
          **eval_kwargs)
Example #4
def main(_):
    g = tf.Graph()
    with g.as_default():
        # If ps_tasks is zero, the local device is used. When using multiple
        # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
        # across the different devices.
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            data_tuple = imagenet_data_provider.provide_data(
                FLAGS.split_name,
                FLAGS.batch_size,
                dataset_dir=FLAGS.dataset_dir,
                is_training=True,
                image_size=FLAGS.image_size)
            images, labels, examples_per_epoch, num_classes = data_tuple

            # Define the model:
            with slim.arg_scope(
                    imagenet_model.resnet_arg_scope(is_training=True)):
                model = utils.split_and_int(FLAGS.model)
                logits, end_points = imagenet_model.get_network(
                    images, model, num_classes, model_type=FLAGS.model_type)

                # Specify the loss function:
                tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                logits=logits,
                                                label_smoothing=0.1,
                                                weights=1.0)
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    training_utils.add_all_ponder_costs(end_points,
                                                        weights=FLAGS.tau)
                total_loss = tf.losses.get_total_loss()

                # Configure the learning rate using an exponential decay.
                decay_steps = int(examples_per_epoch / FLAGS.batch_size *
                                  FLAGS.num_epochs_per_decay)

                learning_rate = tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    slim.get_or_create_global_step(),
                    decay_steps,
                    FLAGS.learning_rate_decay_factor,
                    staircase=True)
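                # With staircase=True, the rate is
                #   FLAGS.learning_rate * FLAGS.learning_rate_decay_factor **
                #   floor(global_step / decay_steps),
                # i.e. it drops by the decay factor once every
                # num_epochs_per_decay epochs.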

                opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

                init_fn = training_utils.finetuning_init_fn(
                    FLAGS.finetune_path)

                # Fold the UPDATE_OPS collection (e.g. batch-norm moving
                # average updates) into every training step.
                train_tensor = slim.learning.create_train_op(
                    total_loss,
                    optimizer=opt,
                    update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

                # Summaries:
                tf.summary.scalar('losses/Total Loss', total_loss)
                tf.summary.scalar('training/Learning Rate', learning_rate)

                metric_map = {}  # summary_utils.flops_metric_map(end_points, False)
                if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'):
                    metric_map.update(
                        summary_utils.act_metric_map(end_points, False))
                for name, value in metric_map.items():
                    tf.summary.scalar(name, value)

                if FLAGS.model_type == 'sact':
                    summary_utils.add_heatmaps_image_summary(end_points,
                                                             border=10)

                startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
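                # Non-chief workers wait task * startup_delay_steps steps
                # before training so replicas come online gradually.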

                slim.learning.train(
                    train_tensor,
                    init_fn=init_fn,
                    logdir=FLAGS.train_log_dir,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    startup_delay_steps=startup_delay_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs)
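All four examples read their configuration from FLAGS, whose definitions are
not part of this listing. The sketch below shows the kind of declarations they
assume; the flag names are inferred from usage above, and every default is a
placeholder rather than a value from the original code:

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Address of the TensorFlow master; empty to run locally.')
flags.DEFINE_string('model_type', 'sact',
                    "Model type; the code branches on 'act', 'act_early_stopping' and 'sact'.")
flags.DEFINE_string('model', '18', 'Model spec, parsed by utils.split_and_int.')
flags.DEFINE_integer('batch_size', 32, 'Number of images per batch.')
flags.DEFINE_float('tau', 0.01, 'Weight of the ACT ponder cost.')
flags.DEFINE_string('train_log_dir', '/tmp/train', 'Directory for checkpoints and summaries.')
FLAGS = flags.FLAGS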