示例#1
0
def reload_teacher(session, hparams):
    """Reload a saved model to use as teacher in self-distillation.

    Builds an inference-mode copy of the network under the ``teacher/model``
    variable scope and restores its weights from the newest checkpoint found
    in ``<hparams.teacher_model>/model``.

    Args:
        session: TensorFlow session in which the teacher variables are
            restored.
        hparams: hyperparameters; ``model_name`` selects the architecture
            ('pyramid_net', 'wrn' or 'shake_shake') and ``teacher_model`` is
            the directory holding the saved teacher.

    Returns:
        A ``(logits, inputs)`` tuple: the teacher's logits tensor and the
        ``[None, 32, 32, 3]`` float placeholder that feeds it.

    Raises:
        ValueError: if ``hparams.model_name`` is not recognized, or if no
            checkpoint can be found for the teacher.
    """
    teacher_scope = 'teacher'
    with tf.variable_scope(teacher_scope, reuse=tf.AUTO_REUSE):
        inputs = tf.placeholder('float', [None, 32, 32, 3])
        scopes = setup_arg_scopes(is_training=False, hparams=hparams)
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            with nested(*scopes):
                if hparams.model_name == 'pyramid_net':
                    logits, _ = build_shake_drop_model(inputs,
                                                       num_classes=10,
                                                       is_training=False)
                elif hparams.model_name == 'wrn':
                    logits, _ = build_wrn_model(inputs,
                                                num_classes=10,
                                                hparams=hparams)
                elif hparams.model_name == 'shake_shake':
                    logits, _ = build_shake_shake_model(inputs,
                                                        num_classes=10,
                                                        hparams=hparams,
                                                        is_training=False)
                else:
                    # Raise instead of `assert 0`: asserts are stripped under
                    # `python -O`, which would leave `logits` unbound.
                    raise ValueError(
                        f'unrecognized hparams.model_name: {hparams.model_name}'
                    )

    ckpt_dir = os.path.join(hparams.teacher_model, 'model')
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is None:
        # latest_checkpoint returns None when nothing is found; fail here with
        # a message that names the directory instead of inside Saver.restore.
        raise ValueError(f'no teacher checkpoint found in {ckpt_dir}')

    # Map each variable name in the checkpoint to the variable to restore.
    # Graph variables are named 'teacher/<ckpt name>:<output index>', so the
    # scope prefix and the output-index suffix are stripped to recover the
    # name stored in the checkpoint.
    scope_prefix = teacher_scope + '/'
    teacher_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_prefix)
    restore_map = {
        var.name[len(scope_prefix):].rsplit(':', 1)[0]: var
        for var in teacher_vars
    }
    saver = tf.train.Saver(var_list=restore_map)
    saver.restore(session, ckpt)
    return logits, inputs
示例#2
0
def build_model(inputs, num_classes, is_training, hparams):
  """Constructs the vision model being trained/evaled.

  Args:
    inputs: input features/images being fed to the image model being built.
    num_classes: number of output classes being predicted.
    is_training: is the model training or not.
    hparams: additional hyperparameters associated with the image model;
      `model_name` selects the architecture ('pyramid_net', 'wrn' or
      'shake_shake').

  Returns:
    The logits of the image model.

  Raises:
    ValueError: if `hparams.model_name` is not one of the supported names.
  """
  scopes = setup_arg_scopes(is_training)
  with contextlib.ExitStack() as stack:
    # Enter every arg scope; ExitStack unwinds them in reverse order on exit.
    for scope in scopes:
      stack.enter_context(scope)
    if hparams.model_name == 'pyramid_net':
      logits = build_shake_drop_model(
          inputs, num_classes, is_training)
    elif hparams.model_name == 'wrn':
      logits = build_wrn_model(
          inputs, num_classes, hparams.wrn_size)
    elif hparams.model_name == 'shake_shake':
      logits = build_shake_shake_model(
          inputs, num_classes, hparams, is_training)
    else:
      # Without this branch an unknown name surfaced as an UnboundLocalError
      # on `return logits`, hiding the real cause.
      raise ValueError(
          'unrecognized hparams.model_name: {}'.format(hparams.model_name))
  return logits
def build_model(inputs, num_classes, is_training, hparams):
    """Constructs the vision model being trained/evaled.

    Args:
        inputs: input features/images being fed to the image model being
            built.
        num_classes: number of output classes being predicted.
        is_training: is the model training or not.
        hparams: additional hyperparameters associated with the image model;
            `model_name` selects the architecture ('pyramid_net', 'wrn' or
            'shake_shake').

    Returns:
        A `(logits, hiddens)` tuple: the logits of the image model and the
        second value returned by the selected model builder.

    Raises:
        ValueError: if `hparams.model_name` is not one of the supported names.
    """
    scopes = helper_utils.setup_arg_scopes(is_training, hparams)
    with helper_utils.nested(*scopes):
        if hparams.model_name == 'pyramid_net':
            logits, hiddens = build_shake_drop_model(inputs, num_classes,
                                                     is_training)
        elif hparams.model_name == 'wrn':
            logits, hiddens = build_wrn_model(inputs, num_classes, hparams)
        elif hparams.model_name == 'shake_shake':
            logits, hiddens = build_shake_shake_model(inputs, num_classes,
                                                      hparams, is_training)
        else:
            # Raise rather than `assert 0`: asserts are removed under
            # `python -O`, which would turn this into an UnboundLocalError at
            # the return statement.
            raise ValueError(
                f'unrecognized hparams.model_name: {hparams.model_name}')
    return logits, hiddens
示例#4
0
def build_model(inputs, num_classes, is_training, hparams):
  """Constructs the vision model being trained/evaled.

  Args:
    inputs: input features/images being fed to the image model being built.
    num_classes: number of output classes being predicted.
    is_training: is the model training or not.
    hparams: additional hyperparameters associated with the image model;
      `model_name` selects the architecture ('pyramid_net', 'wrn' or
      'shake_shake').

  Returns:
    The logits of the image model.

  Raises:
    ValueError: if `hparams.model_name` is not one of the supported names.
  """
  # contextlib.nested only exists on Python 2; on Python 3 an equivalent is
  # rebuilt on top of ExitStack so the same `with nested(...)` works on both.
  try:
    from contextlib import nested  # Python 2
  except ImportError:
    from contextlib import ExitStack, contextmanager

    @contextmanager
    def nested(*contexts):
      """Reimplementation of contextlib.nested for Python 3."""
      with ExitStack() as stack:
        # Like the Python 2 original, yield the values produced by entering
        # each context rather than the context managers themselves.
        yield tuple(stack.enter_context(ctx) for ctx in contexts)

  scopes = setup_arg_scopes(is_training)
  with nested(*scopes):
    if hparams.model_name == 'pyramid_net':
      logits = build_shake_drop_model(
          inputs, num_classes, is_training)
    elif hparams.model_name == 'wrn':
      logits = build_wrn_model(
          inputs, num_classes, hparams.wrn_size)
    elif hparams.model_name == 'shake_shake':
      logits = build_shake_shake_model(
          inputs, num_classes, hparams, is_training)
    else:
      # Without this branch an unknown name surfaced as an UnboundLocalError
      # on `return logits`, hiding the real cause.
      raise ValueError(
          'unrecognized hparams.model_name: {}'.format(hparams.model_name))
  return logits
示例#5
0
def build_model(inputs, num_classes, is_training, hparams):
  """Constructs the vision model being trained/evaled.

  Args:
    inputs: input features/images being fed to the image model being built.
    num_classes: number of output classes being predicted.
    is_training: is the model training or not.
    hparams: additional hyperparameters associated with the image model;
      `model_name` selects the architecture ('pyramid_net', 'wrn' or
      'shake_shake').

  Returns:
    The logits of the image model.

  Raises:
    ValueError: if `hparams.model_name` is not one of the supported names.
  """
  scopes = setup_arg_scopes(is_training)
  # contextlib.nested was removed in Python 3; ExitStack is its replacement
  # for entering a variable number of context managers.
  with contextlib.ExitStack() as stack:
    for scope in scopes:
      stack.enter_context(scope)
    if hparams.model_name == 'pyramid_net':
      logits = build_shake_drop_model(
          inputs, num_classes, is_training)
    elif hparams.model_name == 'wrn':
      logits = build_wrn_model(
          inputs, num_classes, hparams.wrn_size)
    elif hparams.model_name == 'shake_shake':
      logits = build_shake_shake_model(
          inputs, num_classes, hparams, is_training)
    else:
      # Without this branch an unknown name surfaced as an UnboundLocalError
      # on `return logits`, hiding the real cause.
      raise ValueError(
          'unrecognized hparams.model_name: {}'.format(hparams.model_name))
  return logits
def main(argv):
  """Restores a CIFAR-10 model, runs `aggregate_interp` and saves its outputs.

  Steps visible in this function:
    1. Reject extra command-line arguments and make sure the `prefix`
       directory exists.
    2. For `FLAGS.dsname == 'cifar10'`, build the hyperparameters from
       `FLAGS.modelname`, load the data and keep only the examples whose label
       is `FLAGS.class_1` (and `FLAGS.class_2` when it differs).
    3. Build the model in inference mode under the 'model' variable scope and
       restore it from `FLAGS.dirname/model` (latest checkpoint, or the one
       numbered `FLAGS.ckpt_num`).
    4. Call `aggregate_interp` and save `logit_dists` and `logits_list` with
       `npsave`.

  Args:
    argv: command-line arguments; anything beyond the program name is an error.

  Raises:
    app.UsageError: if more than one command-line argument is given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # `prefix` is defined outside this function.
  if not os.path.isdir(prefix):
    os.makedirs(prefix)
  if FLAGS.dsname == 'cifar10':
    hparams = tf.contrib.training.HParams(
        train_size=50000,
        validation_size=0,
        eval_test=True,
        dataset='cifar10',
        data_path='./cifar10_data/',
        extra_dataset='cifar10_1',
        use_batchnorm=1,
        use_fixup=0,
        use_gamma_swish=0)
    # Translate the model flag into an architecture name plus its size
    # hyperparameter.
    # NOTE(review): there is no else branch, so an unrecognized
    # FLAGS.modelname leaves `model_name` unset on hparams — confirm that the
    # flag is validated elsewhere.
    if FLAGS.modelname == 'wrn_32':
      setattr(hparams, 'model_name', 'wrn')
      hparams.add_hparam('wrn_size', 32)
    elif FLAGS.modelname == 'wrn_160':
      setattr(hparams, 'model_name', 'wrn')
      hparams.add_hparam('wrn_size', 160)
    elif FLAGS.modelname == 'shake_shake_32':
      setattr(hparams, 'model_name', 'shake_shake')
      hparams.add_hparam('shake_shake_widen_factor', 2)
    elif FLAGS.modelname == 'shake_shake_96':
      setattr(hparams, 'model_name', 'shake_shake')
      hparams.add_hparam('shake_shake_widen_factor', 6)
    elif FLAGS.modelname == 'shake_shake_112':
      setattr(hparams, 'model_name', 'shake_shake')
      hparams.add_hparam('shake_shake_widen_factor', 7)
    elif FLAGS.modelname == 'pyramid_net':
      setattr(hparams, 'model_name', 'pyramid_net')
      hparams.batch_size = 64
    (all_images, all_labels, test_images, test_labels, extra_test_images,
     extra_test_labels) = data_utils_cifar.load_cifar(hparams)
    # Default to the test split; FLAGS.split may select 'train' or 'extra'.
    images = test_images
    labels = test_labels
    if FLAGS.split == 'train':
      images = all_images
      labels = all_labels
    elif FLAGS.split == 'extra':
      images = extra_test_images
      labels = extra_test_labels
    # Keep only the examples of class_1.
    # Presumably labels are one-hot, so argmax over the last axis recovers the
    # class index — TODO confirm against load_cifar.
    images1 = images[np.argmax(labels, axis=-1) == FLAGS.class_1, Ellipsis]
    labels1 = labels[np.argmax(labels, axis=-1) == FLAGS.class_1, Ellipsis]

    # Build the network in inference mode inside its own graph.
    g = tf.Graph()
    with g.as_default():
      inputs = tf.placeholder('float', [None, 32, 32, 3])
      scopes = helper_utils.setup_arg_scopes(is_training=False, hparams=hparams)
      with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
        with helper_utils.nested(*scopes):
          if hparams.model_name == 'pyramid_net':
            logits, hiddens = build_shake_drop_model(
                inputs, num_classes=10, is_training=False)
          elif hparams.model_name == 'wrn':
            logits, hiddens = build_wrn_model(
                inputs, num_classes=10, hparams=hparams)
          elif hparams.model_name == 'shake_shake':
            logits, hiddens = build_shake_shake_model(
                inputs, num_classes=10, hparams=hparams, is_training=False)
          else:
            print(f'unrecognized hparams.model_name: {hparams.model_name}')
            assert 0

    sess = tf.InteractiveSession(graph=g)
    # Restore either the newest checkpoint or the explicitly requested one.
    if FLAGS.ckpt_num is None:
      ckpt = tf.train.latest_checkpoint(os.path.join(FLAGS.dirname, 'model'))
    else:
      ckpt = os.path.join(FLAGS.dirname, 'model',
                          'modelckpt.ckpt-' + str(FLAGS.ckpt_num))
    saver = tf.train.Saver()
    saver.restore(sess, ckpt)
    # Wrap the graph as a plain callable that maps images to evaluated logits.
    model = lambda imgs: logits.eval(feed_dict={inputs: imgs})

    # When two different classes are requested, the class_2 examples are
    # passed as a second dataset; otherwise only the class_1 examples are used.
    if FLAGS.class_2 != FLAGS.class_1:
      images2 = images[np.argmax(labels, axis=-1) == FLAGS.class_2, Ellipsis]
      labels2 = labels[np.argmax(labels, axis=-1) == FLAGS.class_2, Ellipsis]
      spectra, max_grads, max_prob_dists, logit_dists, logits_list = aggregate_interp(
          dataset=(images1, labels1),
          model=model,
          projection=None,
          dataset2=(images2, labels2),
          numpy=True)
    else:
      spectra, max_grads, max_prob_dists, logit_dists, logits_list = aggregate_interp(
          dataset=(images1, labels1),
          model=model,
          projection=None,
          dataset2=None,
          numpy=True)
  else:
    logging.warn('unsupported dataset')
    assert False

  # Save the outputs to files
  # The file name encodes the model, both classes, the interpolation settings
  # and the number of pairs, plus the checkpoint directory and number when
  # applicable.
  filename = FLAGS.modelname + '_' + str(
      FLAGS.class_1) + FLAGS.interp_type + str(FLAGS.sampling_distance) + str(
          FLAGS.class_2) + '_' + str(FLAGS.num_pairs)
  if FLAGS.dsname == 'cifar10':
    filename = filename + '_dir_' + FLAGS.dirname.replace('/', '.')
  if FLAGS.ckpt_num is not None:
    filename = filename + '_ckpt' + str(FLAGS.ckpt_num)
  # NOTE(review): spectra, max_grads and max_prob_dists are not saved in the
  # visible part of this function.
  npsave(filename + '_logitdists.npz', logit_dists)
  npsave(filename + '_logitslist.npz', logits_list)